# Source code for neuro_fuzzy_toolbox.training.sonfis

import torch

from neuro_fuzzy_toolbox.training import (
    base_model_trainer
)

import pandas as pd
import numpy as np

class SONFIS(base_model_trainer):
    """
    Self-Organizing Neuro-Fuzzy Inference System (SONFIS) algorithm.

    Combines a parameter learning algorithm with three structural adaptation
    operators (GrowNet, SplitSubNet and VanishNet) to iteratively update both
    the parameters and the structure of a rule-reduced ANFIS model.

    Note:
        This algorithm is only applicable to instances of
        :class:`rule_reduced_ANFIS`.
    """

    def __init__(self, Ngrow, dGrow, Nsplit, eSplit, Nvanish, lVanish,
                 max_iterations, ANFIStrainer, early_stopping=None,
                 lse_for_new_consequents=False,
                 lse_for_new_consequents_lambda=0.,
                 last_training_iteration=False):
        """
        Initializes a new SONFIS instance.

        Args:
            Ngrow (int): GrowNet only creates a new subnet from a group of
                poorly modeled samples when that group has strictly more
                than ``Ngrow`` samples.
            dGrow (float): Threshold for identifying poorly modeled samples.
                A sample is considered poorly modeled if its maximum firing
                level across all subnets is less than or equal to
                ``dGrow ** input_size``.
            Nsplit (int): A subnet is only considered for splitting when
                strictly more than ``Nsplit`` samples are associated with it.
            eSplit (float): A candidate subnet is split when the loss over
                its associated samples is strictly greater than ``eSplit``.
            Nvanish (int): A subnet with fewer than ``Nvanish`` associated
                samples has its age counter incremented; otherwise its age
                is reset to 0.
            lVanish (int): An underused subnet is removed once its age is
                greater than or equal to ``lVanish``.
            max_iterations (int): Maximum number of structural adaptation
                iterations.
            ANFIStrainer (base_model_trainer): Instantiated ANFIS training
                algorithm that defines how the model parameters are updated
                at each iteration.
            early_stopping (EarlyStopping): Early stopping mechanism to use
                during the SONFIS iterations. Defaults to ``None``.
            lse_for_new_consequents (bool): If ``True``, the consequent
                parameters of rules generated by GrowNet or SplitSubNet are
                initialized using least-squares estimation instead of random
                initialization. Defaults to ``False``.
            lse_for_new_consequents_lambda (float): Lambda value for Ridge
                regularization in the least-squares initialization of new
                consequent parameters. If ``0.``, no regularization is
                applied. Defaults to ``0.``.
            last_training_iteration (bool): If ``True``, performs a final
                parameter update over all subnets after the SONFIS algorithm
                finishes. Defaults to ``False``.
        """
        # ------------- SONFIS hyperparameters -------------
        self.Ngrow = Ngrow
        self.dGrow = dGrow
        self.Nsplit = Nsplit
        self.eSplit = eSplit
        self.Nvanish = Nvanish
        self.lVanish = lVanish
        self.max_iterations = max_iterations
        self.lse_for_new_consequents = lse_for_new_consequents
        self.lse_for_new_consequents_lambda = lse_for_new_consequents_lambda
        self.last_training_iteration = last_training_iteration

        # Early stopping
        self.sonfis_early_stopping = early_stopping

        # Loss histories; the last entry is read right after each
        # ``_register_loss`` call in ``__call__``.
        self.history = {"loss": []}
        self.val_history = {"loss": []}

        # --------- ANFIS trainer ---------
        self.trainer = ANFIStrainer
        self.loss_function = self.trainer.loss_function

        # Optimizer
        self._optimizer_instance = None

        # ------ Internal state ------
        # One entry per subnet: True means "do not train this subnet".
        self._freezed = torch.tensor([], dtype=torch.bool)
        # One entry per subnet: consecutive iterations spent underused.
        self._ages = torch.tensor([], dtype=torch.int)
        self._get_max_firing_level = None

        # ------ Rules DataFrame (only maintained when verbose=True) ------
        self._rules_dataframe = None
        self._current_best_rules_dataframe = None
        self._current_max_idx = 0
        self._best_rules_dataframe_iter = None

    def __call__(self, ANFISmodel, train_loader, val_loader=None, verbose=True):
        """
        Runs the SONFIS algorithm.

        At each iteration, the structural adaptation operators GrowNet,
        SplitSubNet and VanishNet are applied, followed by a parameter update
        of the modified (unfrozen) subnets. The algorithm stops when no
        structural updates occur or the maximum number of iterations is
        reached. If early stopping is enabled and a validation DataLoader is
        provided, the algorithm may also stop early based on the validation
        loss.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model to train.
            train_loader (DataLoader): DataLoader containing the training data.
            val_loader (DataLoader): DataLoader containing the validation
                data. Defaults to ``None``.
            verbose (bool): If ``True``, prints the rules DataFrame and
                structural update messages at each iteration. Defaults to
                ``True``.

        Raises:
            ValueError: If ``ANFISmodel`` is not rule reduced.
        """
        if not ANFISmodel._rule_reduced:
            raise ValueError('The ANFIS model must be rule reduced')

        self._set_max_firing_level_method(ANFISmodel)
        self._current_max_idx = ANFISmodel.rules - 1
        self._register_loss(ANFISmodel, train_loader, val_loader)
        self._ages = torch.zeros(ANFISmodel.rules, dtype=torch.int)
        self._freezed = torch.zeros(ANFISmodel.rules, dtype=torch.bool)

        # Initialised regardless of ``verbose``: the early-stopping branch
        # below sets/reads these even when nothing is printed.
        early_stop_flag = False
        if val_loader is not None:
            self._best_rules_dataframe_iter = 0

        iter_width = len(str(self.max_iterations))
        print(f'ITERATION: {0:{iter_width}}/{self.max_iterations}')
        self.trainer._sonfis_update_parameters(ANFISmodel, train_loader, val_loader, self._freezed)
        if verbose:
            self._rules_dataframe = ANFISmodel.get_rules_structure().reset_index(drop=True)
            if val_loader is not None:
                self._current_best_rules_dataframe = self._rules_dataframe.copy()
            print("\nSTARTING STATE:")
            print(self._rules_dataframe.to_string())
        self._print_losses(val_loader)
        print(f'\t --> ANFIS rules: {ANFISmodel.rules}\n')

        model_updated = True
        i = 0
        while model_updated and i < self.max_iterations:
            print(f'\nITERATION: {i+1:{iter_width}}/{self.max_iterations}')
            # Only the subnets touched by the structural operators (they are
            # appended unfrozen) get trained in this iteration.
            self._freeze_subnets()
            model_updated = self._update_structure(ANFISmodel, train_loader, verbose)
            if model_updated:
                self.trainer._sonfis_update_parameters(ANFISmodel, train_loader, val_loader, self._freezed)
                if verbose:
                    self._replace_trained_subnets_on_rules_dataframe(ANFISmodel)
                    print("\nCURRENT STATE:")
                    print(self._rules_dataframe.to_string())
                self._register_loss(ANFISmodel, train_loader, val_loader)
                self._print_losses(val_loader)
                print(f'\t --> ANFIS rules: {ANFISmodel.rules}\n')
                if (val_loader is not None) and (self.sonfis_early_stopping is not None):
                    if self._check_early_stop(ANFISmodel, self.val_history["loss"][-1]):
                        print(f'found on the {self._best_rules_dataframe_iter}° iteration.')
                        early_stop_flag = True
                        break
                    if self.sonfis_early_stopping._counter == 0:
                        # New best validation loss: remember this iteration.
                        # The rules DataFrame only exists when verbose=True.
                        if verbose:
                            self._current_best_rules_dataframe = self._rules_dataframe.copy()
                        self._best_rules_dataframe_iter = i + 1
            else:
                print('NO MORE UPDATES')
            i += 1

        if i == self.max_iterations:
            print('MAX ITERATIONS REACHED')

        self._unfreeze_subnets()
        if self.last_training_iteration:
            print('\nLast training iteration (all subnets are being trained again)')
            self.trainer._sonfis_update_parameters(ANFISmodel, train_loader, val_loader, self._freezed)
            if verbose:
                if early_stop_flag:
                    self._rules_dataframe = ANFISmodel.get_rules_structure().reset_index(drop=True)
                else:
                    self._replace_trained_subnets_on_rules_dataframe(ANFISmodel)
                print("LAST TRAINING STATE:")
                print(self._rules_dataframe.to_string())
            self._register_loss(ANFISmodel, train_loader, val_loader)
            self._print_losses(val_loader, leading="")

        print('\nTRAINING FINISHED')
        print(f'\t --> ANFIS rules: {ANFISmodel.rules}\n')
        if verbose:
            if early_stop_flag:
                self._rules_dataframe = ANFISmodel.get_rules_structure().reset_index(drop=True)
            print(self._rules_dataframe.to_string())

    def _print_losses(self, val_loader, leading="\n"):
        """
        Prints the most recent training (and validation) loss.

        Args:
            val_loader (DataLoader): Validation DataLoader; when ``None`` only
                the training loss is printed.
            leading (str): Text printed before the tab-indented loss line.
        """
        if val_loader is not None:
            print(f'{leading}\tloss: {self.history["loss"][-1]:.6f} - validation loss: {self.val_history["loss"][-1]:.6f}')
        else:
            print(f'{leading}\tloss: {self.history["loss"][-1]:.6f}')

    # ----- Freezed subnets -----

    def _freeze_subnets(self):
        """
        Freezes all subnets, preventing their parameters from being updated
        during the next training step.
        """
        self._freezed = torch.ones_like(self._freezed, dtype=torch.bool)

    def _unfreeze_subnets(self):
        """
        Unfreezes all subnets, allowing their parameters to be updated during
        the next training step.
        """
        self._freezed = torch.zeros_like(self._freezed, dtype=torch.bool)

    # ----- Early Stopping -----

    def _check_early_stop(self, ANFISmodel, loss):
        """
        Checks whether the SONFIS early stopping criterion is met.

        If early stopping is triggered, the subnet ages and frozen states are
        reset to match the model's current number of rules, and the early
        stopping mechanism is reinitialized.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model being
                trained.
            loss (float): Current validation loss value.

        Returns:
            bool: ``True`` if training should be stopped, ``False`` otherwise.
        """
        if self.sonfis_early_stopping is None:
            return False
        self.sonfis_early_stopping(ANFISmodel, loss, verbose=True)
        if not self.sonfis_early_stopping.stop:
            return False
        self._ages = torch.zeros(ANFISmodel.rules, dtype=torch.int)
        self._freezed = torch.zeros(ANFISmodel.rules, dtype=torch.bool)
        self.sonfis_early_stopping.reset()
        return True

    # ----- Rules dataframe -----

    @staticmethod
    def _flatten_subnets(premises, consequents):
        """
        Flattens premise and consequent tensors into one row per subnet.

        Args:
            premises (torch.Tensor): Shape ``(input_size, n_subnets, mf_params)``.
            consequents (torch.Tensor): Shape ``(outputs, n_subnets, input_size + 1)``.

        Returns:
            numpy.ndarray: Array of shape ``(n_subnets, n_columns)`` whose
            columns are the flattened premises followed by the flattened
            consequents.
        """
        n_subnets = premises.shape[1]
        premise_rows = premises.permute(1, 0, 2).reshape(n_subnets, -1)
        consequent_rows = consequents.permute(1, 0, 2).reshape(n_subnets, -1)
        return torch.cat([premise_rows, consequent_rows], dim=1).detach().cpu().numpy()

    def _add_subnets_on_rules_dataframe(self, new_premises, new_consequents):
        """
        Appends rows for newly added subnets to the internal rules DataFrame.

        New rows are labelled with consecutive integers following the largest
        label handed out so far (``_current_max_idx``).

        Args:
            new_premises (torch.Tensor): Premise parameters of the new subnets,
                of shape ``(input_size, n_new_subnets, mf_params)``.
            new_consequents (torch.Tensor): Consequent parameters of the new
                subnets, of shape ``(outputs, n_new_subnets, input_size + 1)``.
        """
        n_new_subnets = new_premises.shape[1]
        if n_new_subnets == 0:
            return
        data_block = self._flatten_subnets(new_premises, new_consequents)
        start = self._current_max_idx + 1
        new_subnets_df = pd.DataFrame(
            data_block,
            columns=self._rules_dataframe.columns,
            index=list(range(start, start + n_new_subnets)),
        )
        self._current_max_idx = start + n_new_subnets - 1
        if self._rules_dataframe.empty:
            self._rules_dataframe = new_subnets_df
        else:
            self._rules_dataframe = pd.concat([self._rules_dataframe, new_subnets_df], axis=0)

    def _drop_subnets_on_rules_dataframe(self, mask):
        """
        Removes the rows flagged by a boolean mask from the internal rules
        DataFrame.

        Args:
            mask (torch.Tensor): Boolean tensor of length ``num_rules``, where
                ``True`` indicates that the corresponding subnet should be
                removed.
        """
        keep_mask = ~mask.cpu().numpy()
        self._rules_dataframe = self._rules_dataframe.loc[keep_mask]

    def _drop_subnets_on_rules_dataframe_by_idxs(self, idxs_tensor):
        """
        Removes the rows at the given positional indices from the internal
        rules DataFrame.

        Args:
            idxs_tensor (torch.Tensor): Integer tensor of positional indices
                of the subnets to remove.
        """
        idxs = idxs_tensor.cpu().numpy()
        keep = np.ones(len(self._rules_dataframe), dtype=bool)
        keep[idxs] = False
        self._rules_dataframe = self._rules_dataframe.iloc[keep]

    def _replace_trained_subnets_on_rules_dataframe(self, ANFISmodel):
        """
        Overwrites the rows of the internal rules DataFrame that correspond
        to subnets not frozen during the last training step with their
        current parameter values.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model whose
                updated subnet parameters are used.
        """
        trained = ~self._freezed
        if not bool(trained.any()):
            return
        data_block = self._flatten_subnets(
            ANFISmodel.get_premises()[:, trained, :],
            ANFISmodel.get_consequents()[:, trained, :],
        )
        self._rules_dataframe.iloc[trained.cpu().numpy().astype(bool), :] = data_block

    # ----- Internal Methods -----

    def _set_max_firing_level_method(self, ANFISmodel):
        """
        Assigns the method used to retrieve the maximum firing level per
        sample, depending on whether the model uses a default rule.

        If the model has a default rule, the last firing level (corresponding
        to the default rule) is excluded from the maximum computation.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model being
                trained.
        """
        if ANFISmodel._default_rule:
            self._get_max_firing_level = lambda firing_levels: torch.max(firing_levels[:, :-1], dim=1)
        else:
            self._get_max_firing_level = lambda firing_levels: torch.max(firing_levels, dim=1)

    @staticmethod
    def _last_subnets_firing_levels(ANFISmodel, firing_levels, n_last):
        """
        Returns the firing-level columns of the last ``n_last`` subnets.

        When the model uses a default rule, its firing level is the last
        column (see ``_set_max_firing_level_method``), so it is skipped.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Model the firing levels come from.
            firing_levels (torch.Tensor): Firing levels of shape
                ``(n_samples, n_columns)``.
            n_last (int): Number of trailing subnets to select.

        Returns:
            torch.Tensor: Tensor of shape ``(n_samples, n_last)``.
        """
        # NOTE(review): assumes normalized firing levels keep the same column
        # layout as ``get_firing_levels(X)`` (default rule last) — confirm.
        end = firing_levels.shape[1] - (1 if ANFISmodel._default_rule else 0)
        return firing_levels[:, end - n_last:end]

    @staticmethod
    def _prepare_lse_targets(ANFISmodel, y, dtype):
        """
        Converts targets into a least-squares right-hand side.

        For softmax models the class labels are mapped to ``0..K-1`` (when the
        model uses custom classes, mirroring the loss computation in
        ``_SplitSubNet``) and one-hot encoded. The result is cast to ``dtype``.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Model being trained.
            y (torch.Tensor): Targets.
            dtype (torch.dtype): Dtype of the design matrix.

        Returns:
            torch.Tensor: Targets ready to be passed to ``torch.linalg.lstsq``.
        """
        if ANFISmodel._output_type == 'softmax':
            if ANFISmodel._custom_classes:
                y = torch.searchsorted(ANFISmodel.classes, y)
            y = torch.nn.functional.one_hot(y.to(torch.int64), ANFISmodel._outputs)
        if y.dtype != dtype:
            y = y.to(dtype)
        return y

    def _ridge_lstsq(self, A, y):
        """
        Solves ``A @ C ≈ y`` in the least-squares sense, optionally with
        Ridge regularization.

        When ``lse_for_new_consequents_lambda > 0``, the system is augmented
        with ``sqrt(lambda) * I`` rows (and zero targets), which is equivalent
        to Ridge regression.

        Args:
            A (torch.Tensor): Design matrix of shape ``(n_samples, p)``.
            y (torch.Tensor): Targets of shape ``(n_samples,)`` or
                ``(n_samples, m)``.

        Returns:
            torch.Tensor: Solution of shape ``(p,)`` or ``(p, m)``.
        """
        lam = self.lse_for_new_consequents_lambda
        if lam > 0.:
            p = A.shape[1]
            ridge = torch.eye(p, dtype=A.dtype, device=A.device) * torch.sqrt(torch.tensor(lam, dtype=A.dtype))
            A = torch.cat([A, ridge], dim=0)
            zeros = torch.zeros((p, *y.shape[1:]), dtype=A.dtype, device=A.device)
            y = torch.cat([y, zeros], dim=0)
        return torch.linalg.lstsq(A, y, rcond=None, driver="gelsd").solution

    # ----- Update structure -----

    def _update_structure(self, ANFISmodel, train_loader, verbose):
        """
        Executes the GrowNet, SplitSubNet and VanishNet structural adaptation
        operators in sequence.

        GrowNet is attempted first. If no new subnets are grown, SplitSubNet
        is attempted. VanishNet is always applied regardless of the outcomes
        of the previous operators.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model to update.
            train_loader (DataLoader): DataLoader containing the training data.
            verbose (bool): If ``True``, prints messages about structural
                changes.

        Returns:
            bool: ``True`` if any structural update was performed, ``False``
            otherwise.
        """
        did_grow = self._GrowNet(ANFISmodel, train_loader, verbose)
        did_split = False if did_grow else self._SplitSubNet(ANFISmodel, train_loader, verbose)
        did_vanish = self._VanishNet(ANFISmodel, train_loader, verbose)
        return did_grow or did_split or did_vanish

    def _GrowNet(self, ANFISmodel, train_loader, verbose):
        """
        Executes the GrowNet operator to add new subnets to the model.

        Poorly modeled samples are those whose maximum firing level across
        all subnets is at or below ``dGrow ** input_size``. They are grouped
        by the subnet that fires most strongly for them, and every group with
        strictly more than ``Ngrow`` samples yields one new subnet whose
        premises are built from the group's per-feature mean and standard
        deviation. If ``lse_for_new_consequents`` is enabled, the consequent
        parameters of the new subnets are estimated using least squares.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model to update.
            train_loader (DataLoader): DataLoader containing the training data.
            verbose (bool): If ``True``, prints messages about the new subnets
                added.

        Returns:
            bool: ``True`` if at least one new subnet was added, ``False``
            otherwise.
        """
        X = train_loader.dataset.tensors[0]
        y = train_loader.dataset.tensors[1]
        firing_levels = ANFISmodel.get_firing_levels(X)
        max_fl = self._get_max_firing_level(firing_levels)

        # using dGrow
        poorly_modeled = max_fl.values <= self.dGrow ** ANFISmodel._input_size
        bad_samples = X[poorly_modeled]
        bad_targets = y[poorly_modeled]
        clusters = max_fl.indices[poorly_modeled]  # subnet associated to each bad sample

        # using Ngrow (boolean mask: stays 1-D even when a single sample survives)
        unique_rules, counts = torch.unique(clusters, return_counts=True)
        keep = torch.isin(clusters, unique_rules[counts > self.Ngrow])
        bad_samples = bad_samples[keep]
        bad_targets = bad_targets[keep]
        clusters = clusters[keep]
        if bad_samples.size(0) == 0:
            return False

        # One boolean mask over ``bad_samples`` per new subnet (ascending rule order).
        rules = [clusters == rule for rule in torch.unique(clusters)]
        means = torch.stack([bad_samples[rule].mean(dim=0) for rule in rules])  # (new_subnets, input_dim)
        stds = torch.stack([bad_samples[rule].std(dim=0) for rule in rules])    # (new_subnets, input_dim)

        new_premises = ANFISmodel._fuzzification_layer._membership_function._grow_new_premise_parameters(means, stds)
        ANFISmodel.set_premises(torch.cat((ANFISmodel.get_premises(), new_premises), dim=1))
        n_new_rules = new_premises.shape[1]
        new_consequents = ANFISmodel._consequent_layer._consequent_function.random_consequents(
            ANFISmodel._outputs, n_new_rules, ANFISmodel._input_size, ANFISmodel._dtype)
        ANFISmodel.set_consequents(torch.cat((ANFISmodel.get_consequents(), new_consequents), dim=1))

        if self.lse_for_new_consequents:
            new_consequents = self._lse_after_GrowNet(ANFISmodel, bad_samples, bad_targets, rules, n_new_rules)
            ANFISmodel.set_consequents(
                torch.cat((ANFISmodel.get_consequents()[:, :-n_new_rules, :], new_consequents), dim=1))

        if verbose:
            self._add_subnets_on_rules_dataframe(new_premises, new_consequents)
            new_labels = list(range(self._current_max_idx - n_new_rules + 1, self._current_max_idx + 1))
            print(f"\t-> Growing {n_new_rules} new subnets: {new_labels}")

        # New subnets start young and trainable.
        self._ages = torch.cat((self._ages, torch.zeros(n_new_rules, dtype=torch.int)))
        self._freezed = torch.cat((self._freezed, torch.zeros(n_new_rules, dtype=torch.bool)))
        return True

    def _lse_after_GrowNet(self, ANFISmodel, samples, targets, rules_mask, n_new_rules):
        """
        Estimates the consequent parameters of the subnets added by GrowNet
        using least-squares estimation.

        Each new subnet's consequents are estimated independently using only
        the samples associated with that subnet. Ridge regularization is
        applied if ``lse_for_new_consequents_lambda`` is greater than ``0.``.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model being
                trained (already containing the new subnets at the end).
            samples (torch.Tensor): Input samples considered for the creation
                of the new subnets, of shape ``(n_samples, input_size)``.
            targets (torch.Tensor): Targets associated with the input samples,
                of shape ``(n_samples,)`` or ``(n_samples, outputs)``.
            rules_mask (list[torch.Tensor]): List of boolean masks of length
                ``n_new_rules``, each indicating which samples are associated
                with the corresponding new subnet.
            n_new_rules (int): Number of new subnets added by GrowNet.

        Returns:
            torch.Tensor: Estimated consequent parameters for the new subnets,
            of shape ``(outputs, n_new_rules, input_size + 1)``.
        """
        blocks = []
        for i, rule in enumerate(rules_mask):
            x = samples[rule]
            w_norm = ANFISmodel.get_firing_levels(x, normalized=True)
            xe = torch.cat([x, torch.ones(x.shape[0], 1, dtype=x.dtype, device=x.device)], dim=1)
            # Normalized firing level of the i-th new subnet, as a column.
            fs = self._last_subnets_firing_levels(ANFISmodel, w_norm, n_new_rules)[:, i].unsqueeze(1)
            y = self._prepare_lse_targets(ANFISmodel, targets[rule], xe.dtype)
            C = self._ridge_lstsq(xe * fs, y)
            blocks.append(C.t().reshape(ANFISmodel._outputs, 1, xe.shape[1]))
        return torch.cat(blocks, dim=1)

    def _SplitSubNet(self, ANFISmodel, train_loader, verbose):
        """
        Executes the SplitSubNet operator to split subnets with high local
        error.

        Subnets that have strictly more than ``Nsplit`` associated samples
        and whose loss over those samples is strictly above ``eSplit`` are
        each removed and replaced by two new subnets appended at the end. If
        ``lse_for_new_consequents`` is enabled, the consequent parameters of
        the resulting subnets are estimated using least squares.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model to update.
            train_loader (DataLoader): DataLoader containing the training data.
            verbose (bool): If ``True``, prints messages about the subnets
                split.

        Returns:
            bool: ``True`` if at least one subnet was split, ``False``
            otherwise.
        """
        samples = train_loader.dataset.tensors[0]
        targets = train_loader.dataset.tensors[1]
        firing_levels = ANFISmodel.get_firing_levels(samples)
        best_rules = self._get_max_firing_level(firing_levels).indices

        # using Nsplit (boolean mask: stays 1-D even when a single sample survives)
        unique_rules, counts = torch.unique(best_rules, return_counts=True)
        keep = torch.isin(best_rules, unique_rules[counts > self.Nsplit])
        if not bool(keep.any()):
            return False

        with torch.no_grad():
            model_outputs = ANFISmodel(samples, return_probs=False)

        model_outputs = model_outputs[keep]
        samples = samples[keep]
        targets = targets[keep]
        best_rules = best_rules[keep]
        unique_rules = torch.unique(best_rules)

        # One boolean mask over the kept samples per candidate subnet.
        rules = [best_rules == rule for rule in unique_rules]
        model_outputs_list = [model_outputs[rule] for rule in rules]
        targets_list = [targets[rule] for rule in rules]

        # Per-subnet loss, computed ONLY on that subnet's associated samples.
        if ANFISmodel._output_type == "softmax" and ANFISmodel._custom_classes:
            # classes are not [0, 1, 2, ...]: map them to class positions
            loss_targets = [torch.searchsorted(ANFISmodel.classes, t).long() for t in targets_list]
        else:
            loss_targets = targets_list
        loss_values = torch.stack([self.loss_function(out, tgt)
                                   for out, tgt in zip(model_outputs_list, loss_targets)])

        # using eSplit
        eSplit_mask = loss_values > self.eSplit
        rules_to_split = unique_rules[eSplit_mask]
        if rules_to_split.numel() == 0:
            return False

        if self.lse_for_new_consequents:
            to_split_samples_list, to_split_targets_list = self._group_samples_for_lse_in_order(
                eSplit_mask, targets_list, samples, rules)

        membership_function = ANFISmodel._fuzzification_layer._membership_function
        consequent_function = ANFISmodel._consequent_layer._consequent_function
        added_premises = []
        added_consequents = []
        # Descending order: removing a subnet never shifts the index of a
        # subnet that is still waiting to be split.
        for idx, rule in enumerate(torch.flip(rules_to_split, [0]).tolist()):
            premises = ANFISmodel.get_premises()
            consequents = ANFISmodel.get_consequents()
            split = membership_function._split_premise_parameters(premises[:, rule:rule + 1, :])
            new_consequent_to_add = consequent_function.random_consequents(
                ANFISmodel._outputs, 2, ANFISmodel._input_size, ANFISmodel._dtype)

            # Drop the split subnet and append its two children at the end.
            ANFISmodel.set_premises(torch.cat((premises[:, :rule, :], premises[:, rule + 1:, :], split), dim=1))
            ANFISmodel.set_consequents(
                torch.cat((consequents[:, :rule, :], consequents[:, rule + 1:, :], new_consequent_to_add), dim=1))
            self._ages = torch.cat((self._ages[:rule], self._ages[rule + 1:], torch.zeros(2, dtype=torch.int)))
            self._freezed = torch.cat((self._freezed[:rule], self._freezed[rule + 1:], torch.zeros(2, dtype=torch.bool)))

            if self.lse_for_new_consequents:
                new_consequent_to_add = self._lse_while_SplitSubNet(
                    ANFISmodel, to_split_samples_list[idx], to_split_targets_list[idx])
                ANFISmodel.set_consequents(
                    torch.cat((ANFISmodel.get_consequents()[:, :-2, :], new_consequent_to_add), dim=1))

            # Record what was actually installed (LSE or random consequents).
            added_premises.append(split)
            added_consequents.append(new_consequent_to_add)

        if verbose:
            # Report DataFrame labels (as GrowNet/VanishNet do), not positions.
            split_labels = [int(self._rules_dataframe.index[r]) for r in rules_to_split.tolist()]
            print(f"\t-> Splitting {rules_to_split.shape[0]} subnets: {split_labels}")
            self._drop_subnets_on_rules_dataframe_by_idxs(rules_to_split)
            self._add_subnets_on_rules_dataframe(torch.cat(added_premises, dim=1),
                                                 torch.cat(added_consequents, dim=1))
        return True

    def _group_samples_for_lse_in_order(self, eSplit_mask, targets_list, samples, rules):
        """
        Groups samples and targets in reverse order for least-squares
        estimation during the SplitSubNet operation.

        Subnets are visited in reverse order to match the (descending) order
        in which they are split, so that the ``k``-th group corresponds to the
        ``k``-th split performed.

        Args:
            eSplit_mask (torch.Tensor): Boolean tensor indicating which
                candidate subnets exceed the ``eSplit`` threshold.
            targets_list (list[torch.Tensor]): Target tensors, one per
                candidate subnet.
            samples (torch.Tensor): Input samples of the candidate subnets, of
                shape ``(n_samples, input_size)``.
            rules (list[torch.Tensor]): Boolean masks indicating which samples
                are associated with each candidate subnet.

        Returns:
            tuple[list[torch.Tensor], list[torch.Tensor]]: The grouped input
            samples and their corresponding targets, ordered to match the
            reverse splitting sequence.
        """
        to_split_samples_list = []
        to_split_targets_list = []
        for i in reversed(range(eSplit_mask.shape[0])):
            if eSplit_mask[i]:
                to_split_targets_list.append(targets_list[i])
                to_split_samples_list.append(samples[rules[i]])
        return to_split_samples_list, to_split_targets_list

    def _lse_while_SplitSubNet(self, ANFISmodel, samples, targets):
        """
        Estimates the consequent parameters of the two subnets resulting from
        a SplitSubNet operation using least-squares estimation.

        The two children are the last two subnets of the model at call time.
        Both are fitted jointly, using only the samples associated with the
        subnet that was split. Ridge regularization is applied if
        ``lse_for_new_consequents_lambda`` is greater than ``0.``.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model being
                trained.
            samples (torch.Tensor): Input samples associated with the subnet
                that was split, of shape ``(n_samples, input_size)``.
            targets (torch.Tensor): Targets associated with those samples, of
                shape ``(n_samples,)`` or ``(n_samples, outputs)``.

        Returns:
            torch.Tensor: Estimated consequent parameters for the two new
            subnets, of shape ``(outputs, 2, input_size + 1)``.
        """
        w_norm = ANFISmodel.get_firing_levels(samples, normalized=True)
        xe = torch.cat([samples, torch.ones(samples.shape[0], 1, dtype=samples.dtype, device=samples.device)], dim=1)
        n_params = xe.shape[1]
        # (n, 2) -> (n, 2 * n_params): each child's firing level is repeated
        # over its own block of consequent parameters.
        w_blocks = self._last_subnets_firing_levels(ANFISmodel, w_norm, 2).repeat_interleave(n_params, dim=1)
        X = xe.repeat(1, 2)
        y = self._prepare_lse_targets(ANFISmodel, targets, X.dtype)
        C = self._ridge_lstsq(X * w_blocks, y)
        return C.t().reshape(ANFISmodel._outputs, 2, n_params)

    def _VanishNet(self, ANFISmodel, train_loader, verbose):
        """
        Executes the VanishNet operator to remove underused subnets.

        A subnet is "underused" when fewer than ``Nvanish`` training samples
        have it as their most strongly firing subnet. Underused subnets have
        their age incremented; all others have it reset to 0. Subnets that
        are underused and whose age is greater than or equal to ``lVanish``
        are removed from the model.

        Args:
            ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model to update.
            train_loader (DataLoader): DataLoader containing the training data.
            verbose (bool): If ``True``, prints messages about the subnets
                removed.

        Returns:
            bool: ``True`` if at least one subnet was removed, ``False``
            otherwise.
        """
        X = train_loader.dataset.tensors[0]
        firing_levels = ANFISmodel.get_firing_levels(X)
        best_rules = self._get_max_firing_level(firing_levels).indices

        # Number of samples whose max firing level belongs to each subnet
        # (subnets that never win get 0).
        unique_rules, counts = torch.unique(best_rules, return_counts=True)
        total_counts = torch.zeros(ANFISmodel.rules, dtype=torch.int64)
        total_counts[unique_rules] = counts

        underused = total_counts < self.Nvanish
        self._ages[underused] += 1
        self._ages[~underused] = 0

        # using Nvanish & lVanish
        mask = (self._ages >= self.lVanish) & underused
        rules_to_eliminate = torch.arange(ANFISmodel.rules)[mask]
        if rules_to_eliminate.numel() == 0:
            return False

        new_premises = ANFISmodel.get_premises()
        new_consequents = ANFISmodel.get_consequents()
        # Descending order keeps the remaining indices valid while slicing.
        for rule in torch.flip(rules_to_eliminate, dims=(0,)).tolist():
            new_premises = torch.cat((new_premises[:, :rule, :], new_premises[:, rule + 1:, :]), dim=1)
            new_consequents = torch.cat((new_consequents[:, :rule, :], new_consequents[:, rule + 1:, :]), dim=1)
            self._ages = torch.cat((self._ages[:rule], self._ages[rule + 1:]))
            self._freezed = torch.cat((self._freezed[:rule], self._freezed[rule + 1:]))
        ANFISmodel.set_premises(new_premises)
        ANFISmodel.set_consequents(new_consequents)

        if verbose:
            # int(): a RangeIndex yields plain ints (no ``.item()``), other
            # integer indexes yield numpy ints; int() handles both.
            labels = [int(self._rules_dataframe.index[i]) for i in rules_to_eliminate.tolist()]
            print(f"\t-> Vanishing {rules_to_eliminate.size(0)} subnets: {labels}")
            self._drop_subnets_on_rules_dataframe(mask)
        return True