# -*- coding: utf-8 -*-
"""
Simple PyTorch Iris classification.

Loads the Iris dataset, trains a small two-layer network with Adam and
cross-entropy loss, and reports accuracy on a held-out test split.

(c) G. Turinici 2025
"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class IrisNet(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        in_features: number of input features (the Iris dataset has 4).
        hidden: width of the hidden layer.
        n_classes: number of output classes (the Iris dataset has 3).

    For an input of shape (batch, in_features), ``forward`` returns raw
    logits of shape (batch, n_classes). No softmax is applied because
    ``nn.CrossEntropyLoss`` applies log-softmax internally.
    """

    def __init__(self, in_features=4, hidden=10, n_classes=3):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, n_classes)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)  # logits; softmax happens inside the loss


def load_iris_tensors(test_size=0.2, random_state=42):
    """Load the Iris dataset and split it into train/test tensors.

    Args:
        test_size: fraction of samples held out for testing.
        random_state: seed for the train/test split.

    Returns:
        (X_train, X_test, y_train, y_test) as torch tensors, in the same
        order as ``sklearn.model_selection.train_test_split``. Features are
        float32; labels are int64 class indices, as ``nn.CrossEntropyLoss``
        expects.
    """
    # Imported here (not at module top) so IrisNet/train/evaluate can be
    # imported and reused without scikit-learn installed.
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split

    iris = load_iris()
    X = iris.data.astype(np.float32)  # Features
    y = iris.target.astype(np.int64)  # Labels as integers
    print(f'X.shape={X.shape}, y.shape={y.shape}')

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    return tuple(torch.tensor(a) for a in (X_train, X_test, y_train, y_test))


def train(model, X_train, y_train, *, epochs=100, batch_size=16, lr=0.001,
          log_every=10):
    """Train ``model`` in place with Adam and cross-entropy on shuffled batches.

    Args:
        model: an ``nn.Module`` mapping features to class logits.
        X_train: feature tensor; the first dimension indexes samples.
        y_train: int64 class-index tensor aligned with ``X_train``.
        epochs: number of passes over the training set.
        batch_size: mini-batch size (the last batch may be smaller).
        lr: Adam learning rate.
        log_every: print the epoch loss every this many epochs; 0 disables.

    Returns:
        List with the mean per-sample training loss of each epoch.

    Raises:
        ValueError: if the training set is empty.
    """
    n_samples = X_train.shape[0]
    if n_samples == 0:
        raise ValueError("training set is empty")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()

    history = []
    for epoch in range(epochs):
        permutation = torch.randperm(n_samples)
        running_loss = 0.0
        for start in range(0, n_samples, batch_size):
            indices = permutation[start:start + batch_size]
            batch_X, batch_y = X_train[indices], y_train[indices]

            optimizer.zero_grad()
            loss = criterion(model(batch_X), batch_y)
            loss.backward()
            optimizer.step()

            # `loss` is a per-sample mean; weight it by the batch size so the
            # epoch figure is a true per-sample mean and a short final batch
            # is not over-counted.
            running_loss += loss.item() * len(indices)

        epoch_loss = running_loss / n_samples
        history.append(epoch_loss)
        if log_every and (epoch + 1) % log_every == 0:
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {epoch_loss:.4f}")
    return history


def evaluate(model, X_test, y_test):
    """Return (accuracy, predicted_classes) of ``model`` on the given data.

    The model is switched to eval mode and evaluated without gradient
    tracking. ``accuracy`` is a Python float in [0, 1].
    """
    model.eval()
    with torch.no_grad():
        predicted_classes = torch.argmax(model(X_test), dim=1)
    accuracy = (predicted_classes == y_test).float().mean().item()
    return accuracy, predicted_classes


def main(seed=0):
    """Run the full pipeline: load, train, evaluate, print; return the model."""
    torch.manual_seed(seed)  # reproducible weight init and batch shuffling

    X_train, X_test, y_train, y_test = load_iris_tensors()
    model = IrisNet(in_features=X_train.shape[1])
    train(model, X_train, y_train)

    accuracy, predicted_classes = evaluate(model, X_test, y_test)
    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"True Labels: {y_test.numpy()}")
    print(f"Predicted Labels: {predicted_classes.numpy()}")
    return model


if __name__ == "__main__":
    main()