import datetime
import hashlib
import json
import os
import pickle
import sys
import tempfile
import math
import numpy as np
import requests
from numpy import median
sys.path.append(os.getcwd())
from src.base.clickhouse_kafka_sender import ClickHouseKafkaSender
from src.base.utils import setup_config
from src.base.kafka_handler import (
ExactlyOnceKafkaConsumeHandler,
KafkaMessageFetchException,
)
from src.base.log_config import get_logger
module_name = "data_analysis.detector"
logger = get_logger(module_name)
BUF_SIZE = 65536 # let's read stuff in 64kb chunks!
config = setup_config()
MODEL = config["pipeline"]["data_analysis"]["detector"]["model"]
CHECKSUM = config["pipeline"]["data_analysis"]["detector"]["checksum"]
MODEL_BASE_URL = config["pipeline"]["data_analysis"]["detector"]["base_url"]
THRESHOLD = config["pipeline"]["data_analysis"]["detector"]["threshold"]
CONSUME_TOPIC = config["environment"]["kafka_topics"]["pipeline"][
"inspector_to_detector"
]
[docs]
class WrongChecksum(Exception): # pragma: no cover
"""Raises when model checksum validation fails."""
pass
[docs]
class Detector:
"""Main component of the Data Analysis stage to perform anomaly detection
Processes suspicious batches from the Inspector using configurable ML models to classify
DNS requests as benign or malicious. Downloads and validates models from a remote server,
extracts features from domain names, calculates probability scores, and generates alerts
when malicious requests are detected above the configured threshold.
"""
def __init__(self) -> None:
self.suspicious_batch_id = None
self.key = None
self.messages = []
self.warnings = []
self.begin_timestamp = None
self.end_timestamp = None
self.model_path = os.path.join(
tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}_model.pickle"
)
self.scaler_path = os.path.join(
tempfile.gettempdir(), f"{MODEL}_{CHECKSUM}_scaler.pickle"
)
self.kafka_consume_handler = ExactlyOnceKafkaConsumeHandler(CONSUME_TOPIC)
self.model, self.scaler = self._get_model()
# databases
self.suspicious_batch_timestamps = ClickHouseKafkaSender(
"suspicious_batch_timestamps"
)
self.alerts = ClickHouseKafkaSender("alerts")
self.logline_timestamps = ClickHouseKafkaSender("logline_timestamps")
self.fill_levels = ClickHouseKafkaSender("fill_levels")
self.fill_levels.insert(
dict(
timestamp=datetime.datetime.now(),
stage=module_name,
entry_type="total_loglines",
entry_count=0,
)
)
[docs]
def get_and_fill_data(self) -> None:
"""Consumes suspicious batches from Kafka and stores them for analysis.
Fetches suspicious batch data from the Inspector via Kafka and stores it in internal
data structures. If the Detector is already busy processing data, consumption is
skipped with a warning. Updates database entries for monitoring and logging purposes.
"""
if self.messages:
logger.warning(
"Detector is busy: Not consuming new messages. Wait for the Detector to finish the "
"current workload."
)
return
key, data = self.kafka_consume_handler.consume_as_object()
if data.data:
self.suspicious_batch_id = data.batch_id
self.begin_timestamp = data.begin_timestamp
self.end_timestamp = data.end_timestamp
self.messages = data.data
self.key = key
self.suspicious_batch_timestamps.insert(
dict(
suspicious_batch_id=self.suspicious_batch_id,
client_ip=key,
stage=module_name,
status="in_process",
timestamp=datetime.datetime.now(),
is_active=True,
message_count=len(self.messages),
)
)
self.fill_levels.insert(
dict(
timestamp=datetime.datetime.now(),
stage=module_name,
entry_type="total_loglines",
entry_count=len(self.messages),
)
)
if not self.messages:
logger.info(
"Received message:\n"
f" ⤷ Empty data field: No unfiltered data available. Belongs to subnet_id {key}."
)
else:
logger.info(
"Received message:\n"
f" ⤷ Contains data field of {len(self.messages)} message(s). Belongs to subnet_id {key}."
)
def _sha256sum(self, file_path: str) -> str:
"""Calculates SHA256 checksum for model file validation.
Args:
file_path (str): Path to the model file to validate.
Returns:
str: SHA256 hexadecimal digest of the file.
"""
h = hashlib.sha256()
with open(file_path, "rb") as file:
while True:
# Reading is buffered, so we can read smaller chunks.
chunk = file.read(h.block_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
def _get_model(self):
"""Downloads and loads ML model and scaler from remote server.
Retrieves the configured model and scaler files from the remote server if not
already present locally. Validates model integrity using SHA256 checksum and
loads the pickled model and scaler objects for inference.
Returns:
tuple: Trained ML model and data scaler objects.
Raises:
WrongChecksum: If model checksum validation fails.
"""
logger.info(f"Get model: {MODEL} with checksum {CHECKSUM}")
if not os.path.isfile(self.model_path):
response = requests.get(
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/{MODEL}.pickle&dl=1"
)
logger.info(
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/{MODEL}.pickle&dl=1"
)
response.raise_for_status()
with open(self.model_path, "wb") as f:
f.write(response.content)
if not os.path.isfile(self.scaler_path):
response = requests.get(
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/scaler.pickle&dl=1"
)
logger.info(
f"{MODEL_BASE_URL}/files/?p=%2F{MODEL}/{CHECKSUM}/scaler.pickle&dl=1"
)
response.raise_for_status()
with open(self.scaler_path, "wb") as f:
f.write(response.content)
# Check file sha256
local_checksum = self._sha256sum(self.model_path)
if local_checksum != CHECKSUM:
logger.warning(
f"Checksum {CHECKSUM} SHA256 is not equal with new checksum {local_checksum}!"
)
raise WrongChecksum(
f"Checksum {CHECKSUM} SHA256 is not equal with new checksum {local_checksum}!"
)
with open(self.model_path, "rb") as input_file:
clf = pickle.load(input_file)
with open(self.scaler_path, "rb") as input_file:
scaler = pickle.load(input_file)
return clf, scaler
[docs]
def clear_data(self) -> None:
"""Clears all data from internal data structures.
Resets messages, timestamps, and warnings to prepare the Detector
for processing the next suspicious batch.
"""
self.messages = []
self.begin_timestamp = None
self.end_timestamp = None
self.warnings = []
def _get_features(self, query: str) -> np.ndarray:
"""Extracts feature vector from domain name for ML model inference.
Computes various statistical and linguistic features from the domain name
including label lengths, character frequencies, entropy measures, and
counts of different character types across domain name levels.
Args:
query (str): Domain name string to extract features from.
Returns:
numpy.ndarray: Feature vector ready for ML model prediction.
"""
# Splitting by dots to calculate label length and max length
query = query.strip(".")
label_parts = query.split(".")
levels = {
"fqdn": query,
"secondleveldomain": label_parts[-2] if len(label_parts) >= 2 else "",
"thirdleveldomain": (
".".join(label_parts[:-2]) if len(label_parts) > 2 else ""
),
}
label_length = len(label_parts)
parts = query.split(".")
label_max = len(max(parts, key=str)) if parts else 0
label_average = len(query)
basic_features = np.array(
[label_length, label_max, label_average], dtype=np.float64
)
alc = "abcdefghijklmnopqrstuvwxyz"
query_len = len(query)
freq = np.array(
[query.lower().count(c) / query_len if query_len > 0 else 0.0 for c in alc],
dtype=np.float64,
)
logger.debug("Get full, alpha, special, and numeric count.")
def calculate_counts(level: str) -> np.ndarray:
if not level:
return np.array([0.0, 0.0, 0.0, 0.0], dtype=np.float64)
full_count = len(level) / len(level)
alpha_ratio = sum(c.isalpha() for c in level) / len(level)
numeric_ratio = sum(c.isdigit() for c in level) / len(level)
special_ratio = sum(
not c.isalnum() and not c.isspace() for c in level
) / len(level)
return np.array(
[full_count, alpha_ratio, numeric_ratio, special_ratio],
dtype=np.float64,
)
fqdn_counts = calculate_counts(levels["fqdn"])
third_counts = calculate_counts(levels["thirdleveldomain"])
second_counts = calculate_counts(levels["secondleveldomain"])
level_features = np.hstack([third_counts, second_counts, fqdn_counts])
def calculate_entropy(s: str) -> float:
if len(s) == 0:
return 0.0
probs = [s.count(c) / len(s) for c in dict.fromkeys(s)]
return -sum(p * math.log(p, 2) for p in probs)
logger.debug("Start entropy calculation")
entropy_features = np.array(
[
calculate_entropy(levels["fqdn"]),
calculate_entropy(levels["thirdleveldomain"]),
calculate_entropy(levels["secondleveldomain"]),
],
dtype=np.float64,
)
logger.debug("Entropy features calculated")
all_features = np.concatenate(
[basic_features, freq, level_features, entropy_features]
)
logger.debug("Finished data transformation")
return all_features.reshape(1, -1)
[docs]
def detect(self) -> None: # pragma: no cover
"""Analyzes DNS requests and identifies malicious domains.
Processes each DNS request in the current batch by extracting features,
running ML model prediction, and collecting warnings for requests that
exceed the configured maliciousness threshold.
"""
logger.info("Start detecting malicious requests.")
for message in self.messages:
# TODO predict all messages
# TODO use scalar: self.scaler.transform(self._get_features(message["domain_name"]))
y_pred = self.model.predict_proba(
self._get_features(message["domain_name"])
)
logger.info(f"Prediction: {y_pred}")
if np.argmax(y_pred, axis=1) == 1 and y_pred[0][1] > THRESHOLD:
logger.info("Append malicious request to warning.")
warning = {
"request": message,
"probability": float(y_pred[0][1]),
"model": MODEL,
"sha256": CHECKSUM,
}
self.warnings.append(warning)
[docs]
def send_warning(self) -> None:
"""Generates and stores alerts for detected malicious requests.
Creates comprehensive alert records from accumulated warnings including
overall risk scores, individual predictions, and metadata. Stores alerts
in the database and updates batch processing status. If no warnings are
present, marks the batch as filtered out.
"""
logger.info("Store alert.")
if len(self.warnings) > 0:
overall_score = median(
[warning["probability"] for warning in self.warnings]
)
alert = {"overall_score": overall_score, "result": self.warnings}
logger.info(f"Add alert: {alert}")
with open(os.path.join(tempfile.gettempdir(), "warnings.json"), "a+") as f:
json.dump(alert, f)
f.write("\n")
self.alerts.insert(
dict(
client_ip=self.key,
alert_timestamp=datetime.datetime.now(),
suspicious_batch_id=self.suspicious_batch_id,
overall_score=overall_score,
domain_names=json.dumps(
[warning["request"] for warning in self.warnings]
),
result=json.dumps(self.warnings),
)
)
self.suspicious_batch_timestamps.insert(
dict(
suspicious_batch_id=self.suspicious_batch_id,
client_ip=self.key,
stage=module_name,
status="finished",
timestamp=datetime.datetime.now(),
is_active=False,
message_count=len(self.messages),
)
)
logline_ids = set()
for message in self.messages:
logline_ids.add(message["logline_id"])
for logline_id in logline_ids:
self.logline_timestamps.insert(
dict(
logline_id=logline_id,
stage=module_name,
status="detected",
timestamp=datetime.datetime.now(),
is_active=False,
)
)
else:
logger.info("No warning produced.")
self.suspicious_batch_timestamps.insert(
dict(
suspicious_batch_id=self.suspicious_batch_id,
client_ip=self.key,
stage=module_name,
status="filtered_out",
timestamp=datetime.datetime.now(),
is_active=False,
message_count=len(self.messages),
)
)
logline_ids = set()
for message in self.messages:
logline_ids.add(message["logline_id"])
for logline_id in logline_ids:
self.logline_timestamps.insert(
dict(
logline_id=logline_id,
stage=module_name,
status="filtered_out",
timestamp=datetime.datetime.now(),
is_active=False,
)
)
self.fill_levels.insert(
dict(
timestamp=datetime.datetime.now(),
stage=module_name,
entry_type="total_loglines",
entry_count=0,
)
)
[docs]
def main(one_iteration: bool = False) -> None: # pragma: no cover
"""Creates and runs the Detector instance in a continuous processing loop.
Initializes the Detector and starts the main processing loop that continuously
fetches suspicious batches from Kafka, performs malicious domain detection,
and generates alerts. Handles various exceptions gracefully and ensures
proper cleanup of data structures.
Args:
one_iteration (bool): For testing purposes - stops loop after one iteration.
Raises:
KeyboardInterrupt: Execution interrupted by user.
"""
logger.info("Starting Detector...")
detector = Detector()
logger.info(f"Detector is running.")
iterations = 0
while True:
if one_iteration and iterations > 0:
break
iterations += 1
try:
logger.debug("Before getting and filling data")
detector.get_and_fill_data()
logger.debug("Inspect Data")
detector.detect()
logger.debug("Send warnings")
detector.send_warning()
except KafkaMessageFetchException as e: # pragma: no cover
logger.debug(e)
except IOError as e:
logger.error(e)
raise e
except ValueError as e:
logger.debug(e)
except KeyboardInterrupt:
logger.info("Closing down Detector...")
break
finally:
detector.clear_data()
if __name__ == "__main__": # pragma: no cover
main()