Example 1: Multiclass Classification on Iris Dataset

This example demonstrates the standard toolbox workflow on the Iris dataset, a four-feature, three-class classification benchmark. An h_ANFIS model is trained using the Basic Optimizer Training Algorithm with early stopping, and the trained model is analyzed using RulesAnalyzer.

Imports and reproducibility

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import (
    confusion_matrix, f1_score, precision_score,
    recall_score, accuracy_score, classification_report
)

import torch
import torch.nn as nn
import torch.utils.data as data
import numpy as np
import random

import neuro_fuzzy_toolbox as nft

SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

Data

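The dataset is split into training (56%), validation (14%), and test (30%) sets using stratified sampling to preserve class proportions: 30% of the samples are held out for testing, and 20% of the remaining 70% are used for validation. The four features are scaled to [0, 1] with a MinMaxScaler fitted on the training set only, so validation and test values can fall slightly outside that range.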

iris = load_iris()
X, y = iris.data, iris.target

x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=SEED
)
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=0.2, stratify=y_train, random_state=SEED
)

scaler = MinMaxScaler(feature_range=(0, 1))

x_train = torch.tensor(scaler.fit_transform(x_train), dtype=torch.float32)
x_val   = torch.tensor(scaler.transform(x_val),       dtype=torch.float32)
x_test  = torch.tensor(scaler.transform(x_test),      dtype=torch.float32)

y_train = torch.tensor(y_train)
y_val   = torch.tensor(y_val)
y_test  = torch.tensor(y_test)
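The split sizes follow directly from the two test_size values. The short check below is only an arithmetic sketch: it assumes the standard 150-sample Iris dataset and uses integer arithmetic rather than calling train_test_split, and the variable names are illustrative.

```python
# Arithmetic sketch of the two-stage split (assumes 150 Iris samples).
n_total = 150

# First split: test_size=0.3 of the full dataset.
n_test = n_total * 30 // 100
n_rest = n_total - n_test

# Second split: test_size=0.2 of the remaining samples becomes validation.
n_val = n_rest * 20 // 100
n_train = n_rest - n_val

# Share of the full dataset that ends up in each subset.
shares = {
    "train": n_train / n_total,
    "val": n_val / n_total,
    "test": n_test / n_total,
}

for name, share in shares.items():
    print(f"{name}: {share:.0%}")
```

With stratification, the 45 test samples correspond to 15 samples per class, which matches the support column of the classification report in the Evaluation section.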

DataLoaders

generator = torch.Generator()
generator.manual_seed(SEED)

train_loader = data.DataLoader(
    data.TensorDataset(x_train, y_train),
    batch_size=8, shuffle=True, generator=generator
)
val_loader = data.DataLoader(
    data.TensorDataset(x_val, y_val),
    batch_size=8, shuffle=False
)
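As a rough guide to what one epoch looks like with these loaders, the sketch below counts the batches produced for 84 training and 21 validation samples (the sizes implied by the split above). It assumes the PyTorch DataLoader default drop_last=False, under which the last, smaller batch is kept; batch_sizes is a hypothetical helper, not part of the toolbox.

```python
def batch_sizes(n_samples, batch_size):
    """Sizes of the batches a loader yields when the last partial batch is kept."""
    full, remainder = divmod(n_samples, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])

# Sizes assumed from the 56% / 14% split of 150 samples.
train_batches = batch_sizes(84, 8)
val_batches = batch_sizes(21, 8)

print(len(train_batches), "training batches per epoch, last one of size", train_batches[-1])
print(len(val_batches), "validation batches, sizes", val_batches)
```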

Model

An h_ANFIS model is instantiated with three membership functions (MFs) per input feature and a softmax output layer for three-class classification. Premise parameters are initialized from the training data distribution, and consequent parameters are estimated by regularized least squares (ridge_lambda=0.1) before gradient-based training.

model = nft.h_ANFIS(
    input_size=4,
    num_mfs=3,
    outputs=3,
    output_type='softmax',
    features=['sepal length', 'sepal width', 'petal length', 'petal width']
)

model.init_premises(x_train)
model.init_consequents(x_train, y_train, ridge_lambda=0.1)
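The regularized least-squares step can be pictured as a closed-form ridge fit. The sketch below illustrates that general technique only; it is not the toolbox's init_consequents implementation, and ridge_consequents, the design matrix A, and the synthetic targets are all hypothetical. In a Takagi-Sugeno model the design matrix would typically hold the inputs weighted by the normalized firing levels, which is not reproduced here.

```python
import numpy as np

def ridge_consequents(design, targets, ridge_lambda=0.1):
    """Closed-form ridge solution W = (A^T A + lambda * I)^-1 A^T Y."""
    n_cols = design.shape[1]
    gram = design.T @ design + ridge_lambda * np.eye(n_cols)
    return np.linalg.solve(gram, design.T @ targets)

# Synthetic, noise-free problem: 50 samples, 5 regressors, 3 outputs.
rng = np.random.default_rng(0)
A = rng.normal(size=(50, 5))
true_W = rng.normal(size=(5, 3))
Y = A @ true_W

# A tiny penalty recovers the generating weights almost exactly.
W_small = ridge_consequents(A, Y, ridge_lambda=1e-8)

# A large penalty shrinks the weights toward zero.
W_large = ridge_consequents(A, Y, ridge_lambda=10.0)

print("max abs error (small lambda):", np.abs(W_small - true_W).max())
print("weight norm small vs large lambda:",
      np.linalg.norm(W_small), np.linalg.norm(W_large))
```

The penalty trades fit for stability: a larger ridge_lambda gives smaller, better-conditioned initial consequents at the cost of a looser fit to the training targets.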

Learning algorithm

The model is trained for up to 500 epochs with AdamW (learning rate 1e-3, weight decay 1e-2). Early stopping monitors the validation loss with patience=30 and delta=1e-4.

trainer = nft.Basic_optimizer_training_algorithm(
    epochs=500,
    loss_function=nn.CrossEntropyLoss(),
    optimizer=torch.optim.AdamW,
    optimizer_params={'lr': 1e-3, 'weight_decay': 1e-2},
    early_stopping=nft.EarlyStopping(patience=30, delta=1e-4)
)

trainer(model, train_loader, val_loader)
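The internals of nft.EarlyStopping are not shown in this example. The sketch below illustrates one common patience-based rule, assuming that delta is the minimum decrease in validation loss that counts as an improvement; PatienceStopper and the loss values are hypothetical and only meant to show the mechanism.

```python
class PatienceStopper:
    """Hypothetical patience-based stopper; not the toolbox's EarlyStopping."""

    def __init__(self, patience=30, delta=1e-4):
        self.patience = patience
        self.delta = delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.delta:
            # Improvement larger than delta: remember it and reset the counter.
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Made-up validation losses, with a short patience so the effect is visible.
stopper = PatienceStopper(patience=3, delta=1e-4)
losses = [0.9, 0.7, 0.6, 0.61, 0.60, 0.62, 0.59]

stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        stopped_at = epoch
        break

print("stopped at epoch", stopped_at, "with best loss", stopper.best)
```

Under this rule the final loss of 0.59 is never seen, because three consecutive epochs without a sufficient improvement end the loop first.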

Evaluation

pred = model.predict(x_test)

acc        = accuracy_score(y_test, pred)
prec       = precision_score(y_test, pred, average='weighted', zero_division=0)
recall     = recall_score(y_test, pred, average='weighted', zero_division=0)
f1         = f1_score(y_test, pred, average='weighted', zero_division=0)
conf_matrix = confusion_matrix(y_test, pred)
class_rep  = classification_report(y_test, pred)

print("Accuracy:", acc)
print("Precision:", prec)
print("Recall:", recall)
print("F1 score:", f1, "\n")

print("Confusion Matrix:")
print(conf_matrix, "\n")

print("Classification Report:")
print(class_rep)

Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1 score: 1.0

Confusion Matrix:
[[15  0  0]
 [ 0 15  0]
 [ 0  0 15]]

Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        15
           1       1.00      1.00      1.00        15
           2       1.00      1.00      1.00        15

    accuracy                           1.00        45
   macro avg       1.00      1.00      1.00        45
weighted avg       1.00      1.00      1.00        45
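The average='weighted' option averages the per-class scores using each class's support as the weight. The sketch below recomputes weighted precision and recall from a confusion matrix in plain Python, following the scikit-learn convention that rows are true classes and columns are predicted classes. weighted_precision_recall is a hypothetical helper, and the second matrix is a made-up example with one misclassified sample.

```python
def weighted_precision_recall(conf):
    """Support-weighted precision and recall from a square confusion matrix."""
    n_classes = len(conf)
    total = sum(sum(row) for row in conf)
    precision = 0.0
    recall = 0.0
    for c in range(n_classes):
        true_pos = conf[c][c]
        support = sum(conf[c])                                  # row sum: true samples of class c
        predicted = sum(conf[r][c] for r in range(n_classes))   # column sum: predictions of class c
        p_c = true_pos / predicted if predicted else 0.0        # mirrors zero_division=0
        r_c = true_pos / support if support else 0.0
        precision += support / total * p_c
        recall += support / total * r_c
    return precision, recall

# Confusion matrix reported above.
reported = [[15, 0, 0], [0, 15, 0], [0, 0, 15]]
prec_reported, rec_reported = weighted_precision_recall(reported)

# Made-up matrix: one class-1 sample predicted as class 2.
one_error = [[15, 0, 0], [0, 14, 1], [0, 0, 15]]
prec_err, rec_err = weighted_precision_recall(one_error)

print(prec_reported, rec_reported)
print(round(prec_err, 4), round(rec_err, 4))
```

Support-weighted recall always equals accuracy, so on a 45-sample test set a single error changes it by 1/45, roughly 2.2 percentage points.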

Rule structure analysis

Once the model is trained, the rule base can be inspected in tabular form and the learned MFs can be visualized per input feature.

print(model.get_rules_structure().to_string())

model.plot_premises(group_by_dim=True)

The RulesAnalyzer class provides rule-level contribution analysis for a specific input sample. The example below retrieves the top 3 rules ranked by their leave-one-rule-out impact on the predicted class probability:

analyzer = nft.RulesAnalyzer(model)

top_rules = analyzer.top_activated_rules(
    x_test[0:1], top_k=3, sort_by='leave_one_rule_out'
)

for class_label, df in top_rules.items():
    print(f"{class_label}:")
    print(df.to_string(), "\n")

class_0:
   rule_id  firing_level  rule_output  contribution  I_logit_margin_max  I_logit_margin_mean  I_prob
0       53  1.751464e-13    -0.144441 -2.529828e-14       -4.459777e-14        -2.978208e-14     0.0
1       62  2.913474e-21    -0.137805 -4.014923e-22       -1.052639e-21        -6.202652e-22     0.0
2       61  9.106845e-24    -0.114750 -1.045014e-24       -2.155477e-24        -1.098971e-24     0.0

class_1:
   rule_id  firing_level  rule_output  contribution  I_logit_margin_max  I_logit_margin_mean        I_prob
0       15  8.753379e-05     0.658258  5.761984e-05            0.000027             0.000172  1.317821e-07
1       17  1.669576e-06     1.864878  3.113556e-06            0.000005             0.000007  2.235174e-08
2       41  2.438206e-07     2.951781  7.197052e-07            0.000001             0.000001  4.656613e-09

class_2:
   rule_id  firing_level  rule_output  contribution  I_logit_margin_max  I_logit_margin_mean    I_prob
0       45      0.972177     3.168142      3.079994            5.205152             5.228931  0.623935
1       18      0.024140     2.764272      0.066728            0.128098             0.133860  0.001298
2       42      0.003525     2.528276      0.008913            0.013312             0.013376  0.000122
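To make the I_prob column concrete, the sketch below shows one way a leave-one-rule-out impact can be computed. It assumes that each class logit is the sum of firing level times rule output over all rules, and that a rule is removed without renormalizing the remaining firing levels; the toolbox's exact definition may differ. The function leave_one_rule_out and the three-rule system are hypothetical.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def leave_one_rule_out(firing, outputs, target_class):
    """Drop in the target-class probability when each rule is removed in turn.

    firing[r] is the firing level of rule r; outputs[r][c] is its output for class c.
    """
    n_classes = len(outputs[0])
    full_logits = [
        sum(w * f[c] for w, f in zip(firing, outputs)) for c in range(n_classes)
    ]
    p_full = softmax(full_logits)[target_class]
    impacts = []
    for w, f in zip(firing, outputs):
        # Subtract this rule's contribution from every class logit.
        reduced = [z - w * f[c] for c, z in enumerate(full_logits)]
        impacts.append(p_full - softmax(reduced)[target_class])
    return impacts

# Made-up system: 3 rules, 3 classes.
firing = [0.90, 0.08, 0.02]
outputs = [[-1.0, -1.0, 3.0], [0.5, 0.2, 1.0], [2.0, 0.0, -1.0]]

impacts = leave_one_rule_out(firing, outputs, target_class=2)
ranking = sorted(range(len(impacts)), key=lambda r: impacts[r], reverse=True)
print("impacts:", [round(i, 4) for i in impacts])
print("rules ranked by impact:", ranking)
```

A positive impact means the rule supports the target class; a negative one means that removing the rule would raise the target-class probability.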
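
The explain_prediction method produces a textual report for the same sample, listing the logits, the class probabilities, and the top rules with their antecedents and consequents:

explanation = analyzer.explain_prediction(x_test[0:1], top_k=3, sort_by="leave_one_rule_out")
print(explanation)

======================================================================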
PREDICTION EXPLANATION
======================================================================

Predicted class: 2
Predicted probability: 0.9908

Logits and probabilities:
  Class 0: logit=-2.2506, p=0.0044
  Class 1: logit=-2.1909, p=0.0047
  Class 2: logit=3.1558, p=0.9908

Explaining predicted class: 2

Top rules (sorted by change in predicted class probability when the rule is removed):
----------------------------------------------------------------------

Rule 45 | w=0.9722 | f(x)=3.1681 | contrib=+3.0800 | I_prob=+0.6239
  IF sepal length ∈ [0.57, 0.80] AND sepal width ∈ [0.30, 0.79] AND petal length ∈ [0.80, 1.17] AND petal width ∈ [0.84, 1.16] THEN f_0(x) = -0.591*sepal length - 0.548*sepal width - 0.578*petal length - 0.561*petal width - 0.577
                                                                                                                                    f_1(x) = -0.596*sepal length - 0.541*sepal width - 0.585*petal length - 0.533*petal width - 0.553
                                                                                                                                    f_2(x) = 0.822*sepal length + 0.640*sepal width + 0.854*petal length + 0.769*petal width + 0.907


Rule 18 | w=0.0241 | f(x)=2.7643 | contrib=+0.0667 | I_prob=+0.0013
  IF sepal length ∈ [-0.14, 0.32] AND sepal width ∈ [0.30, 0.79] AND petal length ∈ [0.80, 1.17] AND petal width ∈ [0.84, 1.16] THEN f_0(x) = -0.800*sepal length - 0.808*sepal width - 0.767*petal length - 0.744*petal width - 0.765
                                                                                                                                     f_1(x) = -0.676*sepal length - 0.588*sepal width - 0.663*petal length - 0.662*petal width - 0.647
                                                                                                                                     f_2(x) = 0.734*sepal length + 0.686*sepal width + 0.712*petal length + 0.701*petal width + 0.702


Rule 42 | w=0.0035 | f(x)=2.5283 | contrib=+0.0089 | I_prob=+0.0001
  IF sepal length ∈ [0.57, 0.80] AND sepal width ∈ [0.30, 0.79] AND petal length ∈ [0.32, 0.58] AND petal width ∈ [0.84, 1.16] THEN f_0(x) = -0.335*sepal length - 0.316*sepal width - 0.331*petal length - 0.318*petal width - 0.340
                                                                                                                                    f_1(x) = -0.430*sepal length - 0.037*sepal width - 0.390*petal length - 0.167*petal width - 0.518
                                                                                                                                    f_2(x) = 0.707*sepal length + 0.202*sepal width + 0.711*petal length + 0.513*petal width + 0.953
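The numbers in the report can be cross-checked by hand. The sketch below uses only values printed above: it recomputes the class probabilities as a softmax of the reported logits and recomputes each rule's contribution as firing level times rule output. That the class-2 logit is close to the sum of the three top contributions is an observation from these numbers, not a documented property of the toolbox.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Logits reported for classes 0, 1, 2.
logits = [-2.2506, -2.1909, 3.1558]
probs = softmax(logits)

# (firing level w, rule output f_2(x)) for rules 45, 18, 42 from the class_2 table.
top_rules_class_2 = [(0.972177, 3.168142), (0.024140, 2.764272), (0.003525, 2.528276)]
contribs = [w * f for w, f in top_rules_class_2]
logit_from_top3 = sum(contribs)

print("probabilities:", [round(p, 4) for p in probs])
print("contributions:", [round(c, 4) for c in contribs])
print("sum of top-3 contributions:", round(logit_from_top3, 4))
```

Rule 45 alone accounts for nearly all of the class-2 logit, which is consistent with its large I_prob value compared with rules 18 and 42.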

Note

For a complete description of the analysis methods available in RulesAnalyzer, see Rule Inspection and Analysis.