import torch
from neuro_fuzzy_toolbox.training import (
base_model_trainer
)
import pandas as pd
import numpy as np
[docs]
class SONFIS(base_model_trainer):
"""
Self-Organizing Neuro-Fuzzy Inference System (SONFIS) algorithm.
Combines a parameter learning algorithm with three structural adaptation
operators (GrowNet, SplitSubNet, and VanishNet) to iteratively update
both the parameters and the structure of a rule-reduced ANFIS model.
Note:
This algorithm is only applicable to instances of :class:`rule_reduced_ANFIS`.
"""
[docs]
def __init__(self, Ngrow, dGrow, Nsplit, eSplit, Nvanish, lVanish, max_iterations, ANFIStrainer, early_stopping=None, lse_for_new_consequents=False, lse_for_new_consequents_lambda=0., last_training_iteration=False):
    """
    Creates a SONFIS trainer from its structural hyperparameters and a parameter trainer.

    Args:
        Ngrow (int): Sample-count threshold used by GrowNet when deciding whether a group of
            poorly modeled samples is large enough to create a new subnet.
        dGrow (float): Firing-level threshold used by GrowNet. A sample is treated as poorly
            modeled when its maximum firing level is less than or equal to
            ``dGrow ** input_size``.
        Nsplit (int): Sample-count threshold a subnet must exceed to be a SplitSubNet candidate.
        eSplit (float): Loss threshold a candidate subnet must exceed to be split.
        Nvanish (int): Sample-count threshold below which a subnet's age counter is incremented.
        lVanish (int): Age at which an underused subnet is removed.
        max_iterations (int): Upper bound on the number of structural adaptation iterations.
        ANFIStrainer (base_model_trainer): Instantiated ANFIS training algorithm used to update
            the model parameters at each iteration. Its ``loss_function`` is reused here.
        early_stopping (EarlyStopping): Optional early stopping mechanism applied across
            SONFIS iterations. Defaults to ``None``.
        lse_for_new_consequents (bool): If ``True``, consequents of rules created by GrowNet or
            SplitSubNet are initialized by least squares instead of randomly. Defaults to ``False``.
        lse_for_new_consequents_lambda (float): Ridge regularization strength for that
            least-squares initialization; ``0.`` disables it. Defaults to ``0.``.
        last_training_iteration (bool): If ``True``, one final parameter update over all subnets
            is run after the structural loop. Defaults to ``False``.
    """
    # --- Wrapped parameter trainer (its loss is shared with this object) ---
    self.trainer = ANFIStrainer
    self.loss_function = ANFIStrainer.loss_function
    self._optimizer_instance = None

    # --- Structural adaptation thresholds ---
    self.Ngrow, self.dGrow = Ngrow, dGrow
    self.Nsplit, self.eSplit = Nsplit, eSplit
    self.Nvanish, self.lVanish = Nvanish, lVanish
    self.max_iterations = max_iterations

    # --- Optional behaviours ---
    self.lse_for_new_consequents = lse_for_new_consequents
    self.lse_for_new_consequents_lambda = lse_for_new_consequents_lambda
    self.last_training_iteration = last_training_iteration
    self.sonfis_early_stopping = early_stopping

    # --- Loss history (training / validation) ---
    self.history = {"loss": []}
    self.val_history = {"loss": []}

    # --- Per-subnet bookkeeping, sized once a model is supplied ---
    self._freezed = torch.empty(0, dtype=torch.int)
    self._ages = torch.empty(0, dtype=torch.int)
    self._get_max_firing_level = None

    # --- Tabular view of the rules ---
    self._rules_dataframe = None
    self._current_best_rules_dataframe = None
    self._current_max_idx = 0
    self._best_rules_dataframe_iter = None
[docs]
def __call__(self, ANFISmodel, train_loader, val_loader=None, verbose=True):
    """
    Runs the SONFIS algorithm.

    After an initial parameter update, each iteration freezes all subnets,
    applies the structural operators (via ``_update_structure``) and, if the
    structure changed, runs a parameter update through the wrapped trainer.
    The loop ends when no structural update occurs, when ``max_iterations``
    is reached, or when the early stopping mechanism (if provided together
    with a validation DataLoader) requests a stop.

    Iteration headers and loss summaries are always printed; ``verbose``
    additionally maintains and prints the rules DataFrame and the messages
    of the structural operators.

    Args:
        ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model to train.
        train_loader (DataLoader): DataLoader containing the training data.
        val_loader (DataLoader): DataLoader containing the validation data. Defaults to ``None``.
        verbose (bool): If ``True``, prints the rules table and structural update messages. Defaults to ``True``.

    Raises:
        ValueError: If ``ANFISmodel`` is not rule reduced.
    """
    if not ANFISmodel._rule_reduced:
        raise ValueError('The ANFIS model must be rule reduced')

    def report_losses(prefix):
        # Prints the most recently registered loss value(s).
        if val_loader is not None:
            print(f'{prefix}loss: {self.history["loss"][-1]:.6f} - validation loss: {self.val_history["loss"][-1]:.6f}')
        else:
            print(f'{prefix}loss: {self.history["loss"][-1]:.6f}')

    self._set_max_firing_level_method(ANFISmodel)
    self._current_max_idx = ANFISmodel.rules - 1
    self._register_loss(ANFISmodel, train_loader, val_loader)
    self._ages = torch.zeros(ANFISmodel.rules, dtype=torch.int)
    self._freezed = torch.zeros(ANFISmodel.rules, dtype=torch.int).bool()
    iter_width = len(str(self.max_iterations))
    print(f'ITERATION: {0:{iter_width}}/{self.max_iterations}')
    self.trainer._sonfis_update_parameters(ANFISmodel, train_loader, val_loader, self._freezed)

    # Initialised unconditionally so the non-verbose path never depends on
    # state that is only created when verbose output is requested.
    early_stop_flag = False
    if val_loader is not None:
        self._best_rules_dataframe_iter = 0
    if verbose:
        self._rules_dataframe = ANFISmodel.get_rules_structure().reset_index(drop=True)
        if val_loader is not None:
            self._current_best_rules_dataframe = self._rules_dataframe.copy()
        print("\nSTARTING STATE:")
        print(self._rules_dataframe.to_string())
    report_losses('\n\t')
    print(f'\t --> ANFIS rules: {ANFISmodel.rules}\n')

    model_updated = True
    i = 0
    while model_updated and i < self.max_iterations:
        print(f'\nITERATION: {i+1:{iter_width}}/{self.max_iterations}')
        # Only subnets created in this iteration are left unfrozen.
        self._freeze_subnets()
        model_updated = self._update_structure(ANFISmodel, train_loader, verbose)
        if model_updated:
            self.trainer._sonfis_update_parameters(ANFISmodel, train_loader, val_loader, self._freezed)
            if verbose:
                self._replace_trained_subnets_on_rules_dataframe(ANFISmodel)
                print("\nCURRENT STATE:")
                print(self._rules_dataframe.to_string())
            self._register_loss(ANFISmodel, train_loader, val_loader)
            report_losses('\n\t')
            print(f'\t --> ANFIS rules: {ANFISmodel.rules}\n')
            if (val_loader is not None) and (self.sonfis_early_stopping is not None):
                if self._check_early_stop(ANFISmodel, self.val_history["loss"][-1]):
                    print(f'found on the {self._best_rules_dataframe_iter}° iteration.')
                    early_stop_flag = True
                    break
                if self.sonfis_early_stopping._counter == 0:
                    # A zero counter is treated as "this iteration is the best so far".
                    self._best_rules_dataframe_iter = i + 1
                    if verbose:
                        # The rules table only exists in verbose mode.
                        self._current_best_rules_dataframe = self._rules_dataframe.copy()
        else:
            print('NO MORE UPDATES')
        i += 1
    if i == self.max_iterations:
        print('MAX ITERATIONS REACHED')

    self._unfreeze_subnets()
    if self.last_training_iteration:
        print('\nLast training iteration (all subnets are being trained again)')
        self.trainer._sonfis_update_parameters(ANFISmodel, train_loader, val_loader, self._freezed)
        if verbose:
            if early_stop_flag:
                # After an early stop the table is rebuilt from the model itself.
                self._rules_dataframe = ANFISmodel.get_rules_structure().reset_index(drop=True)
            else:
                self._replace_trained_subnets_on_rules_dataframe(ANFISmodel)
            print("LAST TRAINING STATE:")
            print(self._rules_dataframe.to_string())
        self._register_loss(ANFISmodel, train_loader, val_loader)
        report_losses('\t')

    print('\nTRAINING FINISHED')
    print(f'\t --> ANFIS rules: {ANFISmodel.rules}\n')
    if verbose:
        if early_stop_flag:
            self._rules_dataframe = ANFISmodel.get_rules_structure().reset_index(drop=True)
        print(self._rules_dataframe.to_string())
# ----- Freezed subnets -----
def _freeze_subnets(self):
    """
    Marks every subnet as frozen so the next training step leaves its parameters untouched.
    """
    # Same shape as the current flags, every entry True.
    self._freezed = torch.ones_like(self._freezed, dtype=torch.bool)
def _unfreeze_subnets(self):
    """
    Marks every subnet as trainable so the next training step may update all of them.
    """
    # Same shape as the current flags, every entry False.
    self._freezed = torch.zeros_like(self._freezed, dtype=torch.bool)
# ----- Early Stopping -----
def _check_early_stop(self, ANFISmodel, loss):
    """
    Feeds the current validation loss to the early stopping mechanism and reports its verdict.

    When the mechanism signals a stop, the subnet ages and frozen flags are
    re-created (all zeros / all unfrozen, sized to the model's current rule
    count) and the mechanism is reset.

    Args:
        ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model being trained.
        loss (float): Current validation loss value.

    Returns:
        bool: ``True`` if training should be stopped, ``False`` otherwise
        (including when no early stopping mechanism is configured).
    """
    stopper = self.sonfis_early_stopping
    if stopper is None:
        return False

    stopper(ANFISmodel, loss, verbose=True)
    if not stopper.stop:
        return False

    # Stop requested: start the per-subnet bookkeeping from scratch.
    self._ages = torch.zeros(ANFISmodel.rules, dtype=torch.int)
    self._freezed = torch.zeros(ANFISmodel.rules, dtype=torch.int).bool()
    stopper.reset()
    return True
# ----- Rules dataframe -----
def _add_subnets_on_rules_dataframe(self, new_premises, new_consequents):
    """
    Adds one row per new subnet to the internal rules DataFrame.

    Each row holds the subnet's flattened premise parameters followed by its
    flattened consequent parameters. Row labels continue from
    ``_current_max_idx``, which is advanced to the last label assigned.

    Args:
        new_premises (torch.Tensor): Premise parameters of the new subnets, of shape ``(input_size, n_new_subnets, mf_params)``.
        new_consequents (torch.Tensor): Consequent parameters of the new subnets, of shape ``(outputs, n_new_subnets, input_size + 1)``.
    """
    count = new_premises.shape[1]
    # Bring the subnet axis first, then flatten everything else into one row per subnet.
    flat_parts = [
        params.permute(1, 0, 2).reshape(count, -1)
        for params in (new_premises, new_consequents)
    ]
    rows = torch.cat(flat_parts, dim=1).cpu().numpy()

    appended = pd.DataFrame(rows, columns=self._rules_dataframe.columns)
    first_label = self._current_max_idx + 1
    appended.index = list(range(first_label, first_label + count))
    self._current_max_idx = appended.index[-1]

    if self._rules_dataframe.empty:
        self._rules_dataframe = appended
    else:
        self._rules_dataframe = pd.concat([self._rules_dataframe, appended], axis=0)
def _drop_subnets_on_rules_dataframe(self, mask):
    """
    Deletes from the internal rules DataFrame the rows flagged by a boolean mask.

    Args:
        mask (torch.Tensor): Boolean tensor of length ``num_rules``; ``True`` marks a subnet
            whose row should be removed.
    """
    # Invert the removal mask to obtain the rows that survive.
    surviving_rows = np.invert(mask.cpu().numpy())
    self._rules_dataframe = self._rules_dataframe.loc[surviving_rows]
def _drop_subnets_on_rules_dataframe_by_idxs(self, idxs_tensor):
    """
    Deletes from the internal rules DataFrame the rows at the given positions.

    Positions are row offsets into the DataFrame (not index labels).

    Args:
        idxs_tensor (torch.Tensor): Tensor of integer positions indicating
            which subnets should be removed.
    """
    positions = idxs_tensor.cpu().numpy()
    # Start from "keep everything" and knock out the requested positions.
    survivors = np.full(len(self._rules_dataframe), True)
    survivors[positions] = False
    self._rules_dataframe = self._rules_dataframe.iloc[survivors]
def _replace_trained_subnets_on_rules_dataframe(self, ANFISmodel):
    """
    Refreshes the rows of the internal rules DataFrame that belong to
    unfrozen subnets with the model's current premise and consequent values.

    Rows of frozen subnets are left as they are. Nothing happens when every
    subnet is frozen.

    Args:
        ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model whose
            current subnet parameters are copied into the DataFrame.
    """
    trained = ~self._freezed
    premises = ANFISmodel.get_premises()[:, trained, :]
    consequents = ANFISmodel.get_consequents()[:, trained, :]

    n_trained = premises.shape[1]
    if n_trained == 0:
        return

    # One flattened row per trained subnet: premises first, consequents after.
    flattened = [
        params.permute(1, 0, 2).reshape(n_trained, -1)
        for params in (premises, consequents)
    ]
    rows = torch.cat(flattened, dim=1).cpu().numpy()
    row_selector = trained.cpu().numpy().astype(bool)
    self._rules_dataframe.iloc[row_selector, :] = rows
# ----- Internal Methods -----
def _set_max_firing_level_method(self, ANFISmodel):
    """
    Selects how the per-sample maximum firing level is computed and stores
    the chosen callable in ``_get_max_firing_level``.

    When the model has a default rule, the last firing-level column is left
    out of the maximum; otherwise all columns are considered. In both cases
    the callable returns the result of a ``max`` over ``dim=1`` (values and
    indices).

    Args:
        ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model being trained.
    """
    def max_without_last_column(firing_levels):
        # Last column is skipped (it belongs to the default rule).
        return firing_levels[:, :-1].max(dim=1)

    def max_over_all_columns(firing_levels):
        return firing_levels.max(dim=1)

    if ANFISmodel._default_rule:
        self._get_max_firing_level = max_without_last_column
    else:
        self._get_max_firing_level = max_over_all_columns
# ----- Update structure -----
def _update_structure(self, ANFISmodel, train_loader, verbose):
    """
    Applies the structural adaptation operators for one iteration.

    GrowNet runs first. SplitSubNet runs only if GrowNet added nothing.
    VanishNet runs afterwards in every case.

    Args:
        ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model to update.
        train_loader (DataLoader): DataLoader containing the training data.
        verbose (bool): If ``True``, prints messages about structural changes.

    Returns:
        bool: ``True`` if any structural update was performed, ``False`` otherwise.
    """
    grew = self._GrowNet(ANFISmodel, train_loader, verbose)
    # Splitting is only attempted when growing did not change the model.
    split = False if grew else self._SplitSubNet(ANFISmodel, train_loader, verbose)
    vanished = self._VanishNet(ANFISmodel, train_loader, verbose)
    return grew or split or vanished
def _GrowNet(self, ANFISmodel, train_loader, verbose):
    """
    Executes the GrowNet operator to add new subnets to the model.

    A sample is poorly modeled when its maximum firing level is less than
    or equal to ``dGrow ** input_size``. Poorly modeled samples are grouped
    by the subnet that fires most strongly for them; every group with more
    than ``Ngrow`` samples produces one new subnet whose premises are built
    from the per-feature mean and standard deviation of the group. New
    consequents are random, or least-squares estimates when
    ``lse_for_new_consequents`` is enabled. New subnets start with age 0
    and unfrozen.

    Args:
        ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model to update.
        train_loader (DataLoader): DataLoader containing the training data.
        verbose (bool): If ``True``, records the new subnets in the rules DataFrame and prints a message.

    Returns:
        bool: ``True`` if at least one new subnet was added, ``False`` otherwise.
    """
    X = train_loader.dataset.tensors[0]
    y = train_loader.dataset.tensors[1]
    firing_levels = ANFISmodel.get_firing_levels(X)
    max_fl = self._get_max_firing_level(firing_levels)

    # dGrow filter: samples no subnet covers well enough.
    poorly_modeled = max_fl.values <= self.dGrow ** ANFISmodel._input_size
    bad_samples = X[poorly_modeled]
    bad_targets = y[poorly_modeled]
    clusters = max_fl.indices[poorly_modeled]  # best-firing subnet of each poorly modeled sample

    # Ngrow filter: keep only groups that are large enough. A boolean mask is
    # used (instead of squeezed index positions) so that a single surviving
    # sample still yields 2-D sample / 1-D cluster tensors.
    unique_rules, counts = torch.unique(clusters, return_counts=True)
    keep = torch.isin(clusters, unique_rules[counts > self.Ngrow])
    bad_samples = bad_samples[keep]
    bad_targets = bad_targets[keep]
    clusters = clusters[keep]
    if bad_samples.size(0) == 0:
        return False

    # One boolean membership mask per group that will spawn a subnet.
    rules = [clusters == rule for rule in torch.unique(clusters)]
    means = torch.stack([bad_samples[rule].mean(dim=0) for rule in rules])  # (new_subnets, input_dim)
    stds = torch.stack([bad_samples[rule].std(dim=0) for rule in rules])  # (new_subnets, input_dim)

    new_premises = ANFISmodel._fuzzification_layer._membership_function._grow_new_premise_parameters(means, stds)
    ANFISmodel.set_premises(torch.cat((ANFISmodel.get_premises(), new_premises), dim=1))
    n_new_rules = new_premises.shape[1]

    new_consequents = ANFISmodel._consequent_layer._consequent_function.random_consequents(ANFISmodel._outputs, n_new_rules, ANFISmodel._input_size, ANFISmodel._dtype)
    ANFISmodel.set_consequents(torch.cat((ANFISmodel.get_consequents(), new_consequents), dim=1))
    if self.lse_for_new_consequents:
        # The model already contains the new subnets, so their firing levels
        # are available; replace the random consequents with LSE estimates.
        new_consequents = self._lse_after_GrowNet(ANFISmodel, bad_samples, bad_targets, rules, n_new_rules)
        ANFISmodel.set_consequents(torch.cat((ANFISmodel.get_consequents()[:, :-n_new_rules, :], new_consequents), dim=1))

    if verbose:
        self._add_subnets_on_rules_dataframe(new_premises, new_consequents)
        new_labels = list(range(self._current_max_idx - n_new_rules + 1, self._current_max_idx + 1))
        print(f"\t-> Growing {n_new_rules} new subnets: {new_labels}")

    self._ages = torch.cat((self._ages, torch.zeros(n_new_rules, dtype=torch.int)))
    self._freezed = torch.cat((self._freezed, torch.zeros(n_new_rules, dtype=torch.bool)))
    return True
def _lse_after_GrowNet(self, ANFISmodel, samples, targets, rules_mask, n_new_rules):
    """
    Computes least-squares consequent parameters for the subnets that
    GrowNet has just appended to the model.

    The new subnets are handled one at a time, each with only the samples
    selected by its mask. Inputs are extended with a constant column and
    weighted by the normalized firing level of the corresponding new subnet
    (taken from the last ``n_new_rules`` firing-level columns). When
    ``lse_for_new_consequents_lambda`` is greater than ``0.``, the system is
    augmented with ``sqrt(lambda) * I`` rows, i.e. Ridge regularization.

    Args:
        ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model being trained.
        samples (torch.Tensor): Input samples considered for the creation of the new subnets, of shape ``(n_samples, input_size)``.
        targets (torch.Tensor): Targets associated with the input samples, of shape ``(n_samples,)`` or ``(n_samples, outputs)``.
        rules_mask (list[torch.Tensor]): One boolean mask per new subnet, selecting that subnet's samples.
        n_new_rules (int): Number of new subnets added by GrowNet.

    Returns:
        torch.Tensor: Estimated consequent parameters for the new subnets, of shape ``(outputs, n_new_rules, input_size + 1)``.
    """
    ridge = self.lse_for_new_consequents_lambda
    new_consequents = torch.tensor([])
    for position, member_mask in enumerate(rules_mask):
        x = samples[member_mask]
        y = targets[member_mask]
        w_norm = ANFISmodel.get_firing_levels(x, normalized=True)
        # Inputs extended with a column of ones for the bias term.
        xe = torch.cat([x, torch.ones(x.shape[0], 1)], dim=1)
        # Negative offset: the new subnets occupy the last n_new_rules columns.
        fs = w_norm[:, position - n_new_rules].unsqueeze(1)

        # Preliminary fix for the dtype issue: one-hot encode class targets
        # and align the target dtype with the design matrix.
        if ANFISmodel._output_type == 'softmax':
            y = torch.nn.functional.one_hot(y.to(torch.int64), ANFISmodel._outputs)
        if y.dtype != xe.dtype:
            y = y.to(xe.dtype)

        A = xe * fs
        if ridge > 0.:
            p = A.shape[1]
            penalty = torch.eye(p, dtype=A.dtype) * torch.sqrt(torch.tensor(ridge, dtype=A.dtype))
            A = torch.cat([A, penalty], dim=0)
            pad_shape = (p, y.shape[1]) if y.dim() > 1 else (p,)
            y = torch.cat([y, torch.zeros(pad_shape, dtype=A.dtype)], dim=0)

        C = torch.linalg.lstsq(A, y, rcond=None, driver="gelsd").solution
        estimate = C.t().reshape(ANFISmodel._outputs, 1, xe.shape[1])
        new_consequents = torch.cat((new_consequents, estimate), dim=1)
    return new_consequents
def _SplitSubNet(self, ANFISmodel, train_loader, verbose):
    """
    Executes the SplitSubNet operator to split subnets with high local
    error.

    Every sample is associated with the subnet that fires most strongly for
    it. Subnets with more than ``Nsplit`` associated samples whose loss on
    those samples exceeds ``eSplit`` are removed and replaced by two new
    subnets appended at the end of the model. The new subnets start with
    age 0 and unfrozen. Their consequents are random, or least-squares
    estimates when ``lse_for_new_consequents`` is enabled.

    Args:
        ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model to update.
        train_loader (DataLoader): DataLoader containing the training data.
        verbose (bool): If ``True``, updates the rules DataFrame and prints the positions of the subnets split.

    Returns:
        bool: ``True`` if at least one subnet was split, ``False`` otherwise.
    """
    samples = train_loader.dataset.tensors[0]
    targets = train_loader.dataset.tensors[1]
    firing_levels = ANFISmodel.get_firing_levels(samples)
    best_rules = self._get_max_firing_level(firing_levels).indices
    with torch.no_grad():
        model_outputs = ANFISmodel(samples, return_probs=False)  # model predictions

    # Nsplit filter. A boolean mask is used rather than squeezed index
    # positions, which collapse to a 0-d tensor when exactly one sample is kept.
    unique_rules, counts = torch.unique(best_rules, return_counts=True)
    keep = torch.isin(best_rules, unique_rules[counts > self.Nsplit])
    if not keep.any():
        return False

    model_outputs = model_outputs[keep]
    samples = samples[keep]
    targets = targets[keep]
    best_rules = best_rules[keep]
    unique_rules = torch.unique(best_rules)

    # One boolean membership mask per candidate subnet, over the kept samples.
    rules = [best_rules == rule for rule in unique_rules]
    model_outputs_list = [model_outputs[rule] for rule in rules]
    targets_list = [targets[rule] for rule in rules]

    # Local loss of each candidate subnet, computed on its own samples only.
    if ANFISmodel._output_type == "softmax" and ANFISmodel._custom_classes:
        # Classes are not [0, 1, 2, ...]: map labels to their position in ``classes``.
        loss_targets = [torch.searchsorted(ANFISmodel.classes, t).long() for t in targets_list]
    else:
        loss_targets = targets_list
    loss_values = torch.stack([self.loss_function(out, tgt) for out, tgt in zip(model_outputs_list, loss_targets)])

    # eSplit filter
    eSplit_mask = loss_values > self.eSplit
    rules_to_split = unique_rules[eSplit_mask]
    if rules_to_split.shape[0] == 0:
        return False

    if self.lse_for_new_consequents:
        to_split_samples_list, to_split_targets_list = self._group_samples_for_lse_in_order(eSplit_mask, targets_list, samples, rules)

    new_premises = ANFISmodel.get_premises()
    new_consequents = ANFISmodel.get_consequents()
    all_new_premises = torch.tensor([])
    all_new_consequents = torch.tensor([])
    # Highest position first, so removing a subnet never shifts the positions
    # of the subnets still waiting to be split.
    for idx, rule in enumerate(torch.flip(rules_to_split, [0]).tolist()):
        to_split = ANFISmodel.get_premises()[:, rule:rule+1, :]
        split = ANFISmodel._fuzzification_layer._membership_function._split_premise_parameters(to_split)
        new_premises = torch.cat((new_premises[:, :rule, :], new_premises[:, rule+1:, :], split), dim=1)

        new_consequent_to_add = ANFISmodel._consequent_layer._consequent_function.random_consequents(ANFISmodel._outputs, 2, ANFISmodel._input_size, ANFISmodel._dtype)
        new_consequents = torch.cat((new_consequents[:, :rule, :], new_consequents[:, rule+1:, :], new_consequent_to_add), dim=1)

        # Bookkeeping mirrors the model: drop the old entry, append two fresh ones.
        self._ages = torch.cat((self._ages[:rule], self._ages[rule+1:], torch.zeros(2, dtype=torch.int)))
        self._freezed = torch.cat((self._freezed[:rule], self._freezed[rule+1:], torch.zeros(2, dtype=torch.bool)))

        if verbose:
            all_new_premises = torch.cat((all_new_premises, split), dim=1)
            all_new_consequents = torch.cat((all_new_consequents, new_consequent_to_add), dim=1)

        ANFISmodel.set_premises(new_premises)
        ANFISmodel.set_consequents(new_consequents)
        if self.lse_for_new_consequents:
            # The two new subnets are now the last two of the model.
            new_2_last_consequents_to_replace = self._lse_while_SplitSubNet(ANFISmodel, to_split_samples_list[idx], to_split_targets_list[idx])
            ANFISmodel.set_consequents(torch.cat((ANFISmodel.get_consequents()[:, :-2, :], new_2_last_consequents_to_replace), dim=1))
            new_premises = ANFISmodel.get_premises()
            new_consequents = ANFISmodel.get_consequents()

    if verbose:
        # Positions (not DataFrame labels) of the subnets that were split.
        subnets = rules_to_split.tolist()
        print(f"\t-> Splitting {rules_to_split.shape[0]} subnets: {subnets}")
        self._drop_subnets_on_rules_dataframe_by_idxs(rules_to_split)
        self._add_subnets_on_rules_dataframe(all_new_premises, all_new_consequents)
    return True
def _group_samples_for_lse_in_order(self, eSplit_mask, targets_list, samples, rules):
    """
    Collects, in reverse subnet order, the samples and targets of the
    subnets flagged for splitting.

    The reverse order matches the order in which SplitSubNet processes the
    flagged subnets, so the n-th entry of each returned list belongs to the
    n-th subnet being split.

    Args:
        eSplit_mask (torch.Tensor): Boolean tensor indicating which subnets exceed the ``eSplit`` threshold and are to be split.
        targets_list (list[torch.Tensor]): List of target tensors, one per subnet considered for splitting.
        samples (torch.Tensor): Input samples of the subnets considered for splitting, of shape ``(n_samples, input_size)``.
        rules (list[torch.Tensor]): List of boolean masks indicating which samples are associated with each subnet.

    Returns:
        tuple[list[torch.Tensor], list[torch.Tensor]]: The grouped input samples and their corresponding targets, ordered from the last flagged subnet to the first.
    """
    # Positions of the flagged subnets, visited from last to first.
    flagged = [
        position
        for position in reversed(range(eSplit_mask.shape[0]))
        if eSplit_mask[position]
    ]
    to_split_samples_list = [samples[rules[position]] for position in flagged]
    to_split_targets_list = [targets_list[position] for position in flagged]
    return to_split_samples_list, to_split_targets_list
def _lse_while_SplitSubNet(self, ANFISmodel, samples, targets):
    """
    Computes least-squares consequent parameters for the two subnets
    created by one split.

    Only the samples of the subnet that was split are used. Inputs are
    extended with a constant column, duplicated (one copy per new subnet),
    and each copy is weighted by the normalized firing level of its subnet
    (the last two firing-level columns). When
    ``lse_for_new_consequents_lambda`` is greater than ``0.``, the system is
    augmented with ``sqrt(lambda) * I`` rows, i.e. Ridge regularization.

    Args:
        ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model being trained.
        samples (torch.Tensor): Input samples associated with the subnet that was split, of shape ``(n_samples, input_size)``.
        targets (torch.Tensor): Targets associated with those samples, of shape ``(n_samples,)`` or ``(n_samples, outputs)``.

    Returns:
        torch.Tensor: Estimated consequent parameters for the two new subnets, of shape ``(outputs, 2, input_size + 1)``.
    """
    ridge = self.lse_for_new_consequents_lambda
    y = targets
    firing = ANFISmodel.get_firing_levels(samples, normalized=True)
    # Inputs extended with a column of ones for the bias term.
    x = torch.cat([samples, torch.ones(samples.shape[0], 1)], dim=1)
    n_params = x.shape[1]

    # Each of the last two firing levels is repeated once per parameter, so it
    # lines up with the matching copy of the extended inputs.
    weights = torch.repeat_interleave(firing[:, -2:], n_params, dim=1)
    X = torch.cat((x, x), dim=1)

    # Preliminary fix for the dtype issue: one-hot encode class targets and
    # align the target dtype with the design matrix.
    if ANFISmodel._output_type == 'softmax':
        y = torch.nn.functional.one_hot(y.to(torch.int64), ANFISmodel._outputs)
    if y.dtype != X.dtype:
        y = y.to(X.dtype)

    A = X * weights
    if ridge > 0.:
        p = A.shape[1]
        penalty = torch.eye(p, dtype=A.dtype) * torch.sqrt(torch.tensor(ridge, dtype=A.dtype))
        A = torch.cat([A, penalty], dim=0)
        pad_shape = (p, y.shape[1]) if y.dim() > 1 else (p,)
        y = torch.cat([y, torch.zeros(pad_shape, dtype=A.dtype)], dim=0)

    C = torch.linalg.lstsq(A, y, rcond=None, driver="gelsd").solution
    return C.t().reshape(ANFISmodel._outputs, 2, n_params)
def _VanishNet(self, ANFISmodel, train_loader, verbose):
    """
    Executes the VanishNet operator to remove underused subnets.

    Every sample is associated with the subnet that fires most strongly for
    it. A subnet with fewer than ``Nvanish`` associated samples has its age
    incremented; otherwise its age is reset to 0. Subnets whose age is at
    least ``lVanish`` while still having fewer than ``Nvanish`` associated
    samples are removed from the model, together with their age and frozen
    entries.

    Args:
        ANFISmodel (rule_reduced_ANFIS): Rule-reduced ANFIS model to update.
        train_loader (DataLoader): DataLoader containing the training data.
        verbose (bool): If ``True``, updates the rules DataFrame and prints the labels of the subnets removed.

    Returns:
        bool: ``True`` if at least one subnet was removed, ``False`` otherwise.
    """
    X = train_loader.dataset.tensors[0]
    firing_levels = ANFISmodel.get_firing_levels(X)
    best_rules = self._get_max_firing_level(firing_levels).indices

    # Number of samples associated with each subnet (0 for subnets that never win).
    unique_rules, counts = torch.unique(best_rules, return_counts=True)
    all_rules = torch.arange(ANFISmodel.rules)
    total_counts = torch.zeros(ANFISmodel.rules, dtype=torch.int64)
    total_counts[unique_rules] = counts

    underused = total_counts < self.Nvanish
    self._ages[underused] += 1  # underused subnets get older
    self._ages[~underused] = 0  # sufficiently used subnets are rejuvenated

    # Nvanish & lVanish filter: old enough AND still underused.
    mask = (self._ages >= self.lVanish) & underused
    rules_to_eliminate = all_rules[mask].tolist()
    if not rules_to_eliminate:
        return False

    new_premises = ANFISmodel.get_premises()
    new_consequents = ANFISmodel.get_consequents()
    # Highest position first, so earlier removals do not shift later ones.
    for rule in reversed(rules_to_eliminate):
        new_premises = torch.cat((new_premises[:, :rule, :], new_premises[:, rule+1:, :]), dim=1)
        new_consequents = torch.cat((new_consequents[:, :rule, :], new_consequents[:, rule+1:, :]), dim=1)
        self._ages = torch.cat((self._ages[:rule], self._ages[rule+1:]))
        self._freezed = torch.cat((self._freezed[:rule], self._freezed[rule+1:]))
    ANFISmodel.set_premises(new_premises)
    ANFISmodel.set_consequents(new_consequents)

    if verbose:
        # int() accepts both the plain Python ints of a RangeIndex and the
        # NumPy integers of a regular integer index.
        labels = [int(self._rules_dataframe.index[i]) for i in rules_to_eliminate]
        print(f"\t-> Vanishing {len(rules_to_eliminate)} subnets: {labels}")
        self._drop_subnets_on_rules_dataframe(mask)
    return True