Source code for train.train

import hashlib
import json
import os
import pickle
import sys
import tempfile
from enum import Enum, unique

import click
import joblib
import numpy as np
import polars as pl
import torch
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler

sys.path.append(os.getcwd())
from src.train.dataset import DatasetLoader
from src.train.model import (
    Pipeline,
)
from src.base.log_config import get_logger
from src.train import CONTEXT_SETTINGS, RESULT_FOLDER, SEED

logger = get_logger("train.train")


[docs] def add_options(options): def _add_options(func): for option in reversed(options): func = option(func) return func return _add_options
[docs] @unique class DatasetEnum(str, Enum): """Available dataset configurations for DGA detection model training""" COMBINE = "combine" CIC = "cic" DGTA = "dgta" DGARCHIVE = "dgarchive"
[docs] @unique class ModelEnum(str, Enum): """Available machine learning algorithms for DGA detection""" RANDOM_FOREST_CLASSIFIER = "rf" XG_BOOST_CLASSIFIER = "xg" GBM_CLASSIFIER = "gbm"
[docs] class DetectorTraining: """Orchestrates end-to-end training of DGA detection models. Manages dataset loading, model selection, training pipeline execution, and model persistence for domain generation algorithm detection. Supports multiple datasets, model types, and handles checksum-based model versioning. """
[docs] def __init__( self, model_name: ModelEnum.RANDOM_FOREST_CLASSIFIER, model_output_path: str = f"./{RESULT_FOLDER}/model", dataset: DatasetEnum = DatasetEnum.COMBINE, data_base_path: str = "./data", max_rows: int = -1, ) -> None: """Initializes training configuration and dataset loading. Sets up model training pipeline with specified algorithm, datasets, and output paths. Handles existing model detection and checksum validation for incremental training workflows. Args: model_name (ModelEnum): ML algorithm type for training. model_output_path (str): Directory path for saving trained models. dataset (DatasetEnum): Dataset configuration for training. data_base_path (str): Base directory containing raw datasets. max_rows (int): Maximum rows per dataset (default: -1 for unlimited). Raises: NotImplementedError: If specified dataset configuration is not supported. """ logger.info("Get DatasetLoader.") self.dataset_loader = DatasetLoader(base_path=data_base_path, max_rows=max_rows) try: model_checksum = self._sha256sum( os.path.join(model_output_path, f"{model_name}.pickle") ) if model_checksum in model_output_path: self.model_checksum = model_checksum self.model_output_path = model_output_path self.model_output_path = self.model_output_path.replace( self.model_checksum, "" ) self.model_output_path = self.model_output_path.replace(model_name, "") self.model_output_path = self.model_output_path.replace("//", "/") except: logger.warning("Model not found, training starts!") self.model_output_path = model_output_path logger.info(self.model_output_path) self.dataset = [] match dataset: case "combine": self.dataset.append(self.dataset_loader.dgta_dataset) self.dataset.append(self.dataset_loader.bambenek_dataset) self.dataset.append(self.dataset_loader.dga_dataset) self.dataset.append(self.dataset_loader.heicloud_dataset) self.dataset = self.dataset + self.dataset_loader.dgarchive_dataset # CIC DNS does work in practice and data is not clean. case "cic": self.dataset.append(self.dataset_loader.cic_dataset) case "dgta": self.dataset.append(self.dataset_loader.dgta_dataset) case "dgarchive": self.dataset.append(self.dataset_loader.dgarchive_data) case _: raise NotImplementedError(f"Dataset not implemented!") logger.info(f"Set up Pipeline.") self.model_name = model_name self.scaler = self._load_scaler() self.model_pipeline = Pipeline( model=self.model_name, datasets=self.dataset, model_output_path=self.model_output_path, scaler=self.scaler, ) self._load_model()
[docs] def explain(self) -> None: """Generates and saves interpretable explanations for the trained model. Extracts decision rules and model interpretations from the trained classifier and saves them to text files for analysis and understanding of model behavior. """ rules = self.model_pipeline.explain( self.model_pipeline.x_val, self.model_pipeline.y_val ) save_path = os.path.join( self.model_output_path, self.model_name, self.model_checksum ) os.makedirs(save_path, exist_ok=True) # Save rules to file with open(os.path.join(save_path, "rules.txt"), "w") as f: f.write("Extracted Rules:\n\n") for i, rule in enumerate(rules, 1): f.write(f"Rule {i}: {rule}\n")
[docs] def test(self) -> None: """Evaluates trained model on all datasets and generates comprehensive reports. Tests model performance across all loaded datasets, computes metrics including classification reports, FDR, and FTTAR. Saves detailed error analysis and misprediction information for model debugging and improvement. """ for X, y, ds in zip( self.model_pipeline.ds_X, self.model_pipeline.ds_y, self.model_pipeline.datasets, ): logger.info("Test validation test.") y_pred = self.model_pipeline.predict(X) y_pred = [round(value) for value in y_pred] y_labels = np.unique(y).tolist() report = classification_report( y, y_pred, output_dict=True, labels=y_labels, zero_division=0 ) logger.info(report) # Get indices of mispredictions mispredicted_indices = [ i for i, (true, pred) in enumerate(zip(y, y_pred)) if true != pred ] mispredictions = [] false_pred = [] # Print or log the mispredicted data points for idx in mispredicted_indices: error = dict() error["y"] = str(y[idx]) error["y_pred"] = str(y_pred[idx]) query = str(ds.data[idx].get_column("query").to_list()[0]) error["query"] = query false_pred.append(query) mispredictions.append(error) # Get matching rows matches = ds.data.filter(pl.col("query").is_in(false_pred)) unique_mispredicton_classes = matches["class"].unique().to_list() model_path = os.path.join( self.model_output_path, self.model_name, self.model_checksum ) os.makedirs(model_path, exist_ok=True) if len(mispredictions) > 0: error_report = dict() error_report["classes"] = unique_mispredicton_classes error_report["mispredictions"] = mispredictions with open(os.path.join(model_path, f"errors_{ds.name}.json"), "w") as f: f.write(json.dumps(error_report) + "\n") with open(os.path.join(model_path, f"results.json"), "a+") as f: results = dict() results["ds"] = ds.name results["results"] = report results["fdr"] = self._fdr(y, y_pred) results["fttar"] = self._fttar(y, y_pred) f.write(json.dumps(results) + "\n")
[docs] def train(self, seed: int = SEED) -> None: """Executes complete model training workflow with evaluation and persistence. Performs hyperparameter optimization, model training, evaluation on test set, and generates comprehensive analysis including model interpretation and performance reports across all datasets. Args: seed (int): Random seed for reproducible training results. """ if seed > 0: np.random.seed(seed) torch.manual_seed(seed) # Training model logger.info("Fit model.") self.model_pipeline.hyperparam_fit() logger.info("Save model") self._save_model() logger.info("Save scaler") self._save_scaler() logger.info("Validate test set") y_pred = self.model_pipeline.predict(self.model_pipeline.x_test) y_pred = [round(value) for value in y_pred] logger.info( classification_report(self.model_pipeline.y_test, y_pred, labels=[0, 1]) ) logger.info("Test model.") self.test() logger.info("Interpret model.") self.explain()
def _fttar(self, y_actual: list[int], y_pred: list[int]) -> float: """Calculates False Positive to True Positive Ratio (FTTAR) metric. Computes the ratio of false positives to true positives, which is useful for understanding the trade-off between detecting malicious domains and generating false alarms in DGA detection systems. Args: y_actual (list[int]): Ground truth binary labels. y_pred (list[int]): Predicted binary labels. Returns: float: FTTAR ratio (0 if no true positives detected). """ _, FP, _, TP = confusion_matrix(y_actual, y_pred, labels=[0, 1]).ravel() if (TP) == 0: logger.debug("WARNING: TP = 0") return 0 return FP / TP def _fdr(self, y_actual: list[int], y_pred: list[int]) -> float: """Calculates False Discovery Rate (FDR) for model evaluation. Computes the proportion of false positives among all positive predictions, which indicates the reliability of positive DGA detections in the model. Args: y_actual (list[int]): Ground truth binary labels. y_pred (list[int]): Predicted binary labels. Returns: float: FDR value (0 if no positive predictions made). """ _, FP, _, TP = confusion_matrix(y_actual, y_pred, labels=[0, 1]).ravel() if (FP + TP) == 0: logger.debug("WARNING: FP + TP = 0") return 0 return FP / (FP + TP) def _load_model(self): try: model_path = os.path.join( self.model_output_path, self.model_name, self.model_checksum, f"{self.model_name}.pickle", ) with open(model_path, "rb") as input_file: self.model_pipeline.model.clf = pickle.load(input_file) except: logger.warning( f"Model could not be loaded. Model path is '{self.model_output_path}' or path incorrect." ) def _save_model(self): logger.info("Save trained model to a file.") with open( os.path.join( tempfile.gettempdir(), f"rf_{self.model_pipeline.trial.number}.pickle" ), "wb", ) as fout: pickle.dump(self.model_pipeline.model.clf, fout) self.model_checksum = self._sha256sum( os.path.join( tempfile.gettempdir(), f"rf_{self.model_pipeline.trial.number}.pickle" ) ) model_path = os.path.join( self.model_output_path, self.model_name, self.model_checksum ) os.makedirs(model_path, exist_ok=True) with open(os.path.join(model_path, f"{self.model_name}.pickle"), "wb") as fout: pickle.dump(self.model_pipeline.model.clf, fout) def _load_scaler(self): try: scaler_path = os.path.join( self.model_output_path, self.model_name, self.model_checksum, "scaler.pickle", ) scaler = joblib.load(scaler_path) logger.info("Scaler loaded successfully.") return scaler except: logger.warning( f"Scaler file not found. Model path is '{self.model_output_path}' or path incorrect." ) return StandardScaler() def _save_scaler(self): """ Save the scaler for future use. """ scaler_path = os.path.join( self.model_output_path, self.model_name, self.model_checksum ) os.makedirs(scaler_path, exist_ok=True) with open(os.path.join(scaler_path, "scaler.pickle"), "wb") as f: pickle.dump(self.scaler, f) def _sha256sum(self, file_path: str) -> str: """Calculates SHA256 checksum for model file integrity verification. Args: file_path (str): Path to the model file to checksum. Returns: str: SHA256 hexadecimal digest for file validation. """ h = hashlib.sha256() with open(file_path, "rb") as file: while True: # Reading is buffered, so we can read smaller chunks. chunk = file.read(h.block_size) if not chunk: break h.update(chunk) return h.hexdigest()
_ds_options = [ click.option( "--dataset", "dataset", default="combine", type=click.Choice(["combine", "dgarchive", "cic", "dgta"]), help="Data set to train model, choose between all available datasets, DGArchive, CIC and DGTA.", ), click.option( "--dataset_path", "dataset_path", type=click.Path(exists=True), help="Dataset path, follow folder structure.", ), click.option( "--dataset_max_rows", "dataset_max_rows", default=-1, type=int, help="Maximum rows to load from each dataset.", ), ] @click.group(context_settings=CONTEXT_SETTINGS) def cli(): click.secho("Train heiDGAF CLI") @cli.command() @add_options(_ds_options) @click.option( "--model", "model", type=click.Choice(["xg", "rf", "gbm"]), help="Model to train, choose between XGBoost and RandomForest classifier", ) @click.option( "--model_output_path", "model_output_path", type=click.Path(), default=f"./{RESULT_FOLDER}/model", help="Model output path. Stores model with {{MODEL}}_{{SHA256}}.pickle.", ) def train( dataset: str, dataset_path: str, dataset_max_rows: int, model: str, model_output_path: str, ) -> None: trainer = DetectorTraining( model_name=model, dataset=dataset, data_base_path=dataset_path, max_rows=dataset_max_rows, model_output_path=model_output_path, ) trainer.train() @cli.command() @add_options(_ds_options) @click.option( "--model", "model", type=click.Choice(["xg", "rf", "gbm"]), help="Model to train, choose between XGBoost and RandomForest classifier", ) @click.option( "--model_path", "model_path", type=click.Path(exists=True), help="Model path.", ) def test( dataset: str, dataset_path: str, dataset_max_rows: int, model: str, model_path: str, ) -> None: trainer = DetectorTraining( dataset=dataset, data_base_path=dataset_path, max_rows=dataset_max_rows, model_output_path=model_path, model_name=model, ) trainer.test() @cli.command() @add_options(_ds_options) @click.option( "--model", "model", type=click.Choice(["xg", "rf", "gbm"]), help="Model to train, choose between XGBoost and RandomForest classifier", ) @click.option( "--model_path", "model_path", type=click.Path(exists=True), help="Model path.", ) def explain( dataset: str, dataset_path: str, dataset_max_rows: int, model: str, model_path: str, ) -> None: trainer = DetectorTraining( dataset=dataset, data_base_path=dataset_path, max_rows=dataset_max_rows, model_output_path=model_path, model_name=model, ) trainer.explain() if __name__ == "__main__": # pragma: no cover cli()