Source code for train.explainer

import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
from scipy.stats import ks_2samp
from scipy.stats import wasserstein_distance
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from te2rules.explainer import ModelExplainer

# Make the current working directory importable so the project-local `src.*`
# imports below resolve -- presumably the script is launched from the
# repository root (TODO confirm).
sys.path.append(os.getcwd())
from src.base.log_config import get_logger
from src.train import RESULT_FOLDER, SEED
from src.train.dataset import Dataset

# Module-level logger used by the classes in this file.
logger = get_logger("train.explainer")


class Plotter:
    """Creates visualizations and plots for dataset analysis and model interpretation.

    Generates various plots including PCA visualizations, t-SNE projections,
    label distributions, and feature analysis plots to understand dataset
    characteristics and model behavior in DGA detection tasks.

    Every plot is written as a PDF below ``<output_path>/<dataset name>/``.
    """

    # Scatter styling per class label, shared by the 2D and 3D PCA plots.
    # Only labels 0 and 1 are known; any other label raises ``KeyError``.
    _PCA_STYLES = {
        0: {"c": "red", "label": "C1", "marker": "*", "alpha": 0.3},
        1: {"c": "green", "label": "C2", "marker": "o", "alpha": 0.5},
    }

    def __init__(self, output_path: str = f"./{RESULT_FOLDER}/data") -> None:
        """
        Args:
            output_path (str): Directory path to save generated visualization files.
        """
        self.output_path = output_path

    def _save_show_close(self, fig, name: str, filename: str) -> None:
        """Writes ``fig`` to ``<output_path>/<name>/<filename>``, then shows and closes it.

        The figure is saved *before* ``plt.show()``: with an interactive
        backend the figure may already be torn down when ``show`` returns, and
        saving afterwards can produce an empty file.

        Args:
            fig: Matplotlib figure to persist.
            name (str): Sub-directory (dataset name) below ``self.output_path``.
            filename (str): File name of the plot inside that sub-directory.
        """
        target_dir = os.path.join(self.output_path, name)
        os.makedirs(target_dir, exist_ok=True)
        fig.savefig(os.path.join(target_dir, filename))
        plt.show()
        plt.close(fig)

    def _plot_pca_2d(self, X: np.ndarray, y: np.ndarray, name: str) -> None:
        """Creates 2D PCA visualization of feature data.

        Reduces dimensionality to 2D using PCA and generates a scatter plot
        with a distinct color and marker per class. Saved as ``pca_2d.pdf``.

        Args:
            X (np.ndarray): Feature matrix for dimensionality reduction.
            y (np.ndarray): Class labels (0/1) for color coding.
            name (str): Dataset name for output file naming.
        """
        projected = PCA(n_components=2).fit_transform(X)

        fig = plt.figure(figsize=(7, 5))
        ax = fig.add_subplot(111)
        fig.patch.set_facecolor("white")
        for label in np.unique(y):
            mask = y == label
            ax.scatter(
                projected[mask, 0],
                projected[mask, 1],
                s=40,
                **self._PCA_STYLES[label],
            )
        ax.set_xlabel("First principal component")
        ax.set_ylabel("Second Principal Component")
        # The scatter calls carry labels; draw the legend like the 3D plot does.
        ax.legend()

        self._save_show_close(fig, name, "pca_2d.pdf")

    def _plot_pca_3d(self, X: np.ndarray, y: np.ndarray, name: str) -> None:
        """Creates 3D PCA visualization of feature data.

        Reduces dimensionality to 3D using PCA and generates a 3D scatter plot
        over the first three principal components. Saved as ``pca_3d.pdf``.

        Args:
            X (np.ndarray): Feature matrix for dimensionality reduction.
            y (np.ndarray): Class labels (0/1) for color coding.
            name (str): Dataset name for output file naming.
        """
        projected = PCA(n_components=3).fit_transform(X)

        fig = plt.figure(figsize=(7, 5))
        ax = fig.add_subplot(111, projection="3d")
        fig.patch.set_facecolor("white")
        for label in np.unique(y):
            mask = y == label
            ax.scatter(
                projected[mask, 0],
                projected[mask, 1],
                projected[mask, 2],
                s=40,
                **self._PCA_STYLES[label],
            )
        ax.set_xlabel("First Principal Component", fontsize=14)
        ax.set_ylabel("Second Principal Component", fontsize=14)
        ax.set_zlabel("Third Principal Component", fontsize=14)
        ax.legend()

        self._save_show_close(fig, name, "pca_3d.pdf")

    def _plot_tsne(
        self,
        X: np.ndarray,
        y: np.ndarray,
        ds_name: str,
        title: str = "t-SNE Plot",
        random_state: int = SEED,
    ) -> None:
        """Creates a 2D t-SNE projection of the feature data.

        Samples with label 0 are drawn as "Benign" (blue) and samples with
        label 1 as "Malicious" (orange). Saved as ``tsne.pdf``.

        Args:
            X (np.ndarray): Feature matrix to embed.
            y (np.ndarray): Class labels (0/1) for color coding.
            ds_name (str): Dataset name for output file naming.
            title (str): Plot title.
            random_state (int): Seed passed to ``TSNE``.
        """
        embedding = TSNE(n_components=2, random_state=random_state).fit_transform(X)

        fig = plt.figure(figsize=(10, 8))
        for label, color, legend_name in zip(
            [0, 1], ["blue", "orange"], ["Benign", "Malicious"]
        ):
            mask = y == label
            plt.scatter(
                embedding[mask, 0],
                embedding[mask, 1],
                c=color,
                label=legend_name,
                alpha=0.7,
            )
        plt.title(title)
        plt.xlabel("t-SNE Component 1")
        plt.ylabel("t-SNE Component 2")
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.legend()
        plt.tight_layout()

        self._save_show_close(fig, ds_name, "tsne.pdf")

    def _plot_label_distribution(self, data: pl.DataFrame, name: str) -> None:
        """Creates bar chart showing distribution of class labels in dataset.

        Plots the frequency of every value in the ``class`` column on a
        logarithmic y-axis. Saved as ``label_distribution.pdf``.

        Args:
            data (pl.DataFrame): Dataset containing class labels in 'class' column.
            name (str): Dataset name for output file naming.
        """
        label_counts = data["class"].value_counts()
        labels = label_counts["class"].to_list()
        counts = label_counts["count"].to_list()

        fig = plt.figure(figsize=(18, 9))
        plt.bar(labels, counts, color=["skyblue", "salmon"])
        plt.xlabel("Label", labelpad=30)
        plt.ylabel("Count")
        plt.yscale("log")
        plt.title("Label Distribution")
        plt.xticks(rotation=90, ha="center")
        plt.grid(axis="y", linestyle="--", alpha=0.6)
        plt.tight_layout()

        self._save_show_close(fig, name, "label_distribution.pdf")

    def _remove_feature(
        self, component: int, X: np.ndarray, y: np.ndarray, pca: PCA, name: str
    ) -> None:
        """Visualizes data after removing specific principal component projection.

        Projects the mean-centered data onto the selected component, subtracts
        that projection from ``X`` and scatter-plots the first two columns of
        the result. Saved as ``pca_pc<component + 1>.pdf``.

        Args:
            component (int): Index of principal component to remove (0-based).
            X (np.ndarray): Original feature matrix.
            y (np.ndarray): Class labels for color coding.
            pca (PCA): Fitted PCA object containing component information.
            name (str): Dataset name for output file naming.
        """
        direction = pca.components_[component]
        centered = X - X.mean(axis=0)
        scores = centered @ direction
        projection = scores.reshape(-1, 1) @ direction.reshape(1, -1)
        # The projection is subtracted from the un-centered data, so the
        # feature means stay in the plotted values.
        residual = X - projection

        fig = plt.figure(figsize=(8, 6))
        plt.scatter(residual[:, 0], residual[:, 1], c=y)
        plt.xlabel("0")
        plt.ylabel("1")
        # Name the component that was actually removed (was hard-coded "PC1").
        plt.title(
            f"Two features from the dataset after removing PC{component + 1}"
        )

        self._save_show_close(fig, name, f"pca_pc{component + 1}.pdf")

    def create_plots_binary(
        self, ds_X: list[np.ndarray], ds_y: list[np.ndarray], data: list[Dataset]
    ) -> None:
        """Generates comprehensive visualization suite for binary classification datasets.

        * Datasets whose name contains ``"dgarchive"`` get a single t-SNE plot
          of their samples concatenated with the samples of the (last)
          dataset whose name contains ``"heicloud"``. If no such dataset is
          given, these combined plots are skipped with a warning.
        * All other datasets get 2D/3D PCA plots, plots with PC1-PC3 removed,
          a label distribution chart and a t-SNE plot.

        Args:
            ds_X (list[np.ndarray]): List of feature matrices for each dataset.
            ds_y (list[np.ndarray]): List of label arrays for each dataset.
            data (list[Dataset]): List of dataset objects containing metadata.
        """
        # Reference samples that are mixed into every dgarchive t-SNE plot.
        reference_X = reference_y = None
        for X, y, ds in zip(ds_X, ds_y, data):
            if "heicloud" in ds.name:
                reference_X, reference_y = X, y

        for X, y, ds in zip(ds_X, ds_y, data):
            if "dgarchive" not in ds.name:
                continue
            if reference_X is None:
                logger.warning(
                    f"No heicloud dataset available; skipping t-SNE plot for {ds.name}"
                )
                continue
            logger.info(f"Plot data for {ds.name}")
            self._plot_tsne(
                np.concatenate((X, reference_X), axis=0),
                np.concatenate((y, reference_y), axis=0),
                ds_name=f"{ds.name}",
            )

        for X, y, ds in zip(ds_X, ds_y, data):
            if "dgarchive" in ds.name:
                continue
            logger.info(f"Plot data for {ds.name}")
            self._plot_pca_2d(X=X, y=y, name=ds.name)
            self._plot_pca_3d(X=X, y=y, name=ds.name)

            # Show the principal components
            pca = PCA().fit(X)
            logger.info("Principal components:")
            logger.info(pca.components_)

            # Remove PC1, PC2 and PC3 in turn
            for component in range(3):
                self._remove_feature(component, X, y, pca, name=ds.name)

            logger.info("Explained variance ratios:")
            logger.info(pca.explained_variance_ratio_)

            self._plot_label_distribution(ds.data, name=ds.name)
            self._plot_tsne(X, y, ds.name)

            # NOTE: per-feature distribution plots are currently disabled.
            # df_data = data.to_pandas()
            # # Assuming your data is in a DataFrame called 'df' with a 'condition' column
            # condition1_data = df_data[df_data["class"] == 1]
            # condition2_data = df_data[df_data["class"] == 0]
            # # List of measurements (you can use all or a subset)
            # measurements = df_data.columns.tolist()[
            #     1:
            # ]  # [1:] to drop the condition column in the beginning
            # self._plot_data_distribution(
            #     data_condition1=condition1_data,
            #     data_condition2=condition2_data,
            #     measurements=measurements,
            #     condition1_name="Benign",
            #     condition2_name="Malicious",
            #     name=ds.name
            # )

    def create_plots_multiclass(
        self, ds_X: list[np.ndarray], ds_y: list[np.ndarray], data: list[Dataset]
    ) -> None:
        """Generates visualizations for multiclass DGA family classification datasets.

        Concatenates the data frames of all datasets whose name contains
        ``"dgarchive"`` and plots their combined label distribution under the
        dataset name ``"dgarchive"``. Does nothing (besides a warning) when no
        such dataset is given.

        Args:
            ds_X (list[np.ndarray]): List of feature matrices for each dataset.
            ds_y (list[np.ndarray]): List of label arrays for each dataset.
            data (list[Dataset]): List of dataset objects containing class information.
        """
        # zip() keeps the original pairing semantics (stops at the shortest list).
        dgarchive_frames = [
            ds.data for _, _, ds in zip(ds_X, ds_y, data) if "dgarchive" in ds.name
        ]
        if not dgarchive_frames:
            # pl.concat() cannot concatenate an empty list.
            logger.warning(
                "No dgarchive dataset available; skipping label distribution plot"
            )
            return

        self._plot_label_distribution(pl.concat(dgarchive_frames), name="dgarchive")

    @staticmethod
    def _clean_values(values, exclude_outliers: bool, percentile) -> np.ndarray:
        """Drops inf/NaN entries and, optionally, values above the given percentile."""
        values = np.where(np.isinf(values), np.nan, values)
        if exclude_outliers:
            threshold = np.nanpercentile(values, percentile)
            values = np.where(values > threshold, np.nan, values)
        return values[~np.isnan(values)]

    def _plot_data_distribution(
        self,
        data_condition1,
        data_condition2,
        measurements,
        ds_name,
        condition1_name="Condition 1",
        condition2_name="Condition 2",
        exclude_outliers=True,
        percentile=99,
    ) -> None:
        """Analyzes the distributions of measurements between two conditions.

        For each measurement the Kolmogorov-Smirnov (KS) statistic and the
        Earth Mover's Distance (EMD) between both conditions are computed and
        logged, and a density plot and a histogram are drawn. The plots are
        saved as ``density_plots.pdf`` and ``histogram_plots.pdf`` below
        ``<output_path>/<ds_name>/``.

        Args:
            data_condition1 (pd.DataFrame): Data for first condition.
            data_condition2 (pd.DataFrame): Data for second condition.
            measurements (list): List of measurement (column) names to analyze.
            ds_name (str): Dataset name for output file naming.
            condition1_name (str): Name of first condition for plotting.
            condition2_name (str): Name of second condition for plotting.
            exclude_outliers (bool): Whether to exclude values above ``percentile``.
            percentile (int): Percentile threshold for outlier exclusion.
        """
        measurements = list(measurements)
        if not measurements:
            # A grid with zero rows cannot be created; nothing to plot.
            logger.warning(
                f"No measurements given; skipping distribution plots for {ds_name}"
            )
            return

        n_cols = 6
        n_rows = int(np.ceil(len(measurements) / n_cols))

        # squeeze=False always yields a 2D axes array, so flatten() gives one
        # axes object per grid cell even when there is only a single row.
        fig_density, axs_density = plt.subplots(
            n_rows, n_cols, figsize=(36, 8 * n_rows), squeeze=False
        )
        fig_hist, axs_hist = plt.subplots(
            n_rows, n_cols, figsize=(36, 8 * n_rows), squeeze=False
        )
        axs_density = axs_density.flatten()
        axs_hist = axs_hist.flatten()

        results = []
        for i, measurement in enumerate(measurements):
            arr1_clean = self._clean_values(
                data_condition1[measurement].values, exclude_outliers, percentile
            )
            arr2_clean = self._clean_values(
                data_condition2[measurement].values, exclude_outliers, percentile
            )

            # Statistics need at least one sample per condition
            if len(arr1_clean) > 0 and len(arr2_clean) > 0:
                ks_stat, ks_p_value = ks_2samp(arr1_clean, arr2_clean)
                emd = wasserstein_distance(arr1_clean, arr2_clean)
            else:
                ks_stat = ks_p_value = emd = np.nan

            results.append(
                {
                    "Measurement": measurement,
                    "KS_Statistic": ks_stat,
                    "KS_p_value": ks_p_value,
                    "EMD": emd,
                    "N_condition1": len(arr1_clean),
                    "N_condition2": len(arr2_clean),
                }
            )

            ax_density = axs_density[i]
            ax_hist = axs_hist[i]
            for values, color, label in (
                (arr1_clean, "blue", condition1_name),
                (arr2_clean, "orange", condition2_name),
            ):
                if len(values) == 0:
                    continue
                sns.kdeplot(
                    values,
                    fill=True,
                    color=color,
                    label=label,
                    ax=ax_density,
                    warn_singular=False,
                )
                ax_hist.hist(values, bins=30, color=color, alpha=0.5, label=label)

            ax_density.set_title(f"{measurement}", fontsize=10)
            ax_density.set_xlabel("Value", fontsize=8)
            ax_density.set_ylabel("Density", fontsize=8)
            ax_density.legend(fontsize=8)

            ax_hist.set_title(f"{measurement}", fontsize=10)
            ax_hist.set_xlabel("Value", fontsize=8)
            ax_hist.set_ylabel("Frequency", fontsize=8)
            ax_hist.legend(fontsize=8)

        # Remove the unused cells of the last grid row
        for j in range(len(measurements), len(axs_density)):
            fig_density.delaxes(axs_density[j])
            fig_hist.delaxes(axs_hist[j])

        fig_density.tight_layout()
        fig_density.subplots_adjust(hspace=0.4, wspace=0.3)
        fig_hist.tight_layout()
        fig_hist.subplots_adjust(hspace=0.4, wspace=0.3)

        output_path = os.path.join(self.output_path, ds_name)
        os.makedirs(output_path, exist_ok=True)
        density_plot_path = os.path.join(output_path, "density_plots.pdf")
        hist_plot_path = os.path.join(output_path, "histogram_plots.pdf")

        # Save before showing (see _save_show_close) and close BOTH figures.
        fig_density.savefig(density_plot_path, dpi=300)
        fig_hist.savefig(hist_plot_path, dpi=300)
        logger.info(f"Density plots saved to: {density_plot_path}")
        logger.info(f"Histogram plots saved to: {hist_plot_path}")
        logger.info(f"Distribution statistics for {ds_name}: {results}")

        plt.show()
        plt.close(fig_density)
        plt.close(fig_hist)
class Explainer:
    """Interprets and explains trained machine learning models for DGA detection.

    Provides model interpretation capabilities including rule extraction and
    threshold rescaling for tree-based models used in domain generation
    algorithm detection tasks.
    """

    # Comparison operators that may sit between a feature name and its threshold.
    _OPERATORS = frozenset({">", "<=", "=", ">=", "<"})

    def __init__(self, output_path: str = f"./{RESULT_FOLDER}") -> None:
        """
        Args:
            output_path (str): Directory path to save interpretation results.
        """
        self.output_path = output_path

    def __rescale_rule(self, rule: str, scaler, feature_names: list[str]) -> str:
        """Rescales feature thresholds in decision rules to original value ranges.

        The rule is split on whitespace. Every ``<feature> <operator> <number>``
        token triple has its number replaced by the inverse-transformed value,
        rounded to 3 digits. Thresholds that are not numeric are left as they
        are. If no feature name occurs in the rule it is returned unchanged;
        otherwise the tokens are re-joined with single spaces.

        Args:
            rule (str): Decision rule string with scaled thresholds
                (e.g., "feature1 > 0.5 and feature2 <= 1.3").
            scaler (sklearn.preprocessing): Fitted scaler with inverse_transform
                method, or ``None`` to skip rescaling.
            feature_names (list[str]): Names of features in original order.

        Returns:
            str: Decision rule with thresholds in original value ranges.

        Raises:
            Exception: Whatever ``scaler.inverse_transform`` raises, e.g. when
                the scaler was fitted on a different number of features.
        """
        # If scaler is none, no rescaling is needed
        if scaler is None:
            return rule

        # No feature mentioned at all: hand the rule back untouched.
        if not any(feature in rule for feature in feature_names):
            return rule

        column_of = {feature: i for i, feature in enumerate(feature_names)}
        tokens = rule.split()
        for j in range(len(tokens) - 2):
            column = column_of.get(tokens[j])
            if column is None or tokens[j + 1] not in self._OPERATORS:
                continue

            # Only the float parse is guarded: a non-numeric threshold is
            # skipped, but scaler errors are no longer silently swallowed.
            try:
                scaled_value = float(tokens[j + 2])
            except ValueError:
                continue

            # One-row input with only this feature's column set. This assumes
            # the scaler transforms each column independently of the others
            # (true for e.g. StandardScaler -- TODO confirm for other scalers).
            placeholder = np.zeros((1, len(feature_names)))
            placeholder[0, column] = scaled_value
            original_value = scaler.inverse_transform(placeholder)[0, column]
            tokens[j + 2] = str(round(original_value, 3))

        return " ".join(tokens)

    def interpret_model(
        self,
        model,
        x_test: np.ndarray,
        y_test: np.ndarray,
        df_cols: list[str],
        scaler=None,
    ) -> list[str]:
        """Interpret a trained model by extracting decision rules and optionally rescaling them.

        Args:
            model (sklearn.ensemble.BaseEnsemble | XGBClassifier): Trained ML model.
            x_test (np.ndarray): Test set features.
            y_test (np.ndarray): Test set labels.
            df_cols (list[str]): Column names of the features.
            scaler (optional): Scaler used in preprocessing, e.g., StandardScaler.
                Defaults to None.

        Returns:
            list[str]: Rules produced by the TE2Rules explainer, with thresholds
            mapped back to original feature ranges when ``scaler`` is given.
        """
        logger.info(f"Feature columns: {df_cols}")

        # Create TE2RULES explainer and generate the explanation
        explainer = ModelExplainer(model, feature_names=df_cols)
        rules = explainer.explain(x_test, y_test.tolist())

        # Rescale values in rules if scaler is provided
        if scaler is not None:
            rules = [self.__rescale_rule(rule, scaler, df_cols) for rule in rules]
        return rules