In [None]:
# Imports
import torch
import torch.nn as nn
import torchvision

In [None]:
# Let's load the dataset
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

size = (64, 64)
num_classes = 2  # keep only Flowers102 labels 0 and 1 -> binary classification problem
transform = torchvision.transforms.Compose([torchvision.transforms.Resize(size), torchvision.transforms.ToTensor()])
train_dataset = list(torchvision.datasets.Flowers102("/tmp/flowers", "train", transform=transform, download=True))
test_dataset = list(torchvision.datasets.Flowers102("/tmp/flowers", "test", transform=transform, download=True))


def _first_classes_to_device(samples: list) -> tuple:
    """Stack the (image, label) pairs whose label is < `num_classes` and move them to `device`.

    Filtering is done on the CPU *before* stacking/transferring, so images of the discarded
    classes (the vast majority of each split) are never copied to the device.

    Args:
        samples: list of (image tensor, int label) pairs, as produced by iterating the dataset.

    Returns:
        (images, labels): images stacked along a new dim 0; labels as an int64 tensor.

    Raises:
        ValueError: if no sample has a label < `num_classes`.
    """
    kept = [(img, label) for img, label in samples if label < num_classes]
    if not kept:
        raise ValueError(f"no samples with label < {num_classes}")
    images = torch.stack([img for img, _ in kept], dim=0).to(device)
    labels = torch.tensor([label for _, label in kept]).to(device)
    return images, labels


# Let's make sure we only have two classes
train_images, train_labels = _first_classes_to_device(train_dataset)
test_images, test_labels = _first_classes_to_device(test_dataset)

# Cheap invariants so a fresh "Run All" fails loudly if the data is not what the model expects.
assert train_images.shape[1:] == (3, *size), train_images.shape
assert set(train_labels.tolist()) <= set(range(num_classes))
assert set(test_labels.tolist()) <= set(range(num_classes))

In [None]:
def accuracy(pred: torch.Tensor, label: torch.Tensor, threshold: float = 0.0) -> float:
    """Fraction of binary predictions that match the labels.

    Args:
        pred: Raw model outputs (logits), i.e. the same values that are fed to
            ``BCEWithLogitsLoss``. A sample is predicted as class 1 when its logit is
            strictly greater than ``threshold``.
        label: Ground-truth 0/1 labels, compared element-wise with the predictions
            (must be broadcastable to ``pred``).
        threshold: Decision threshold on the logit. The default ``0.0`` corresponds to a
            predicted probability of 0.5, because ``sigmoid(0) == 0.5``.

    Returns:
        Mean accuracy as a Python float (NaN for empty input).
    """
    # Thresholding logits at 0.5 would instead require sigmoid(logit) > ~0.62,
    # biasing the predictions toward class 0.
    predicted_class = pred > threshold
    return (predicted_class == label).float().mean().item()


torch.manual_seed(0)  # make the weight initialisation (and therefore the whole run) reproducible

model = torch.nn.Linear(in_features=size[0] * size[1] * 3, out_features=1)
model = model.to(device)

loss_fn = nn.BCEWithLogitsLoss()
optim = torch.optim.SGD(model.parameters(), lr=2e-2)
num_epochs = 500
log_every = 25  # print training loss/accuracy every `log_every` epochs
eval_every = 100  # evaluate on the test split every `eval_every` epochs

# The linear model consumes flat feature vectors: (N, C, H, W) -> (N, C*H*W).
# These tensors never change, so flatten/convert them once instead of on every epoch.
train_inputs = train_images.view(train_images.shape[0], -1)
test_inputs = test_images.view(test_images.shape[0], -1)
train_targets = train_labels.float()  # BCEWithLogitsLoss needs float targets

for epoch in range(num_epochs):
    # Full-batch gradient descent: one optimizer step over all training images per epoch.
    pred = model(train_inputs)[..., 0]  # (N, 1) -> (N,) logits
    loss_val = loss_fn(pred, train_targets)

    optim.zero_grad()
    loss_val.backward()
    optim.step()

    is_last_epoch = epoch == num_epochs - 1
    if epoch % log_every == 0 or is_last_epoch:
        # `pred` and `loss_val` were computed before this epoch's optimizer step.
        print(f"{epoch =:5d}  loss = {loss_val.item():.2f}  accuracy(train) = {accuracy(pred, train_labels):.3f}")

    if epoch % eval_every == 0 or is_last_epoch:
        with torch.inference_mode():
            test_pred = model(test_inputs)[..., 0]
            print(f"   Accuracy (test): {accuracy(test_pred, test_labels):.3f}")