In [None]:
import torch
import torchvision
from PIL import Image

# Load the Oxford Flowers-102 train and test splits. Every image is resized to a
# fixed (height, width) so samples can later be stacked into one batch tensor,
# then converted to a float tensor by ToTensor.
size = (128, 128)
data_root = "/tmp/flowers"
transform = torchvision.transforms.Compose(
    [torchvision.transforms.Resize(size), torchvision.transforms.ToTensor()]
)
# list(...) decodes and transforms every sample once, up front, so later cells can
# index/slice/re-iterate the data without re-reading images from disk.
train_dataset = list(
    torchvision.datasets.Flowers102(data_root, split="train", transform=transform, download=True)
)
test_dataset = list(
    torchvision.datasets.Flowers102(data_root, split="test", transform=transform, download=True)
)


def visualize_image(img: torch.Tensor) -> Image.Image:
    """Convert a channel-first float image tensor into a PIL image for display.

    Args:
        img: tensor laid out as (channels, height, width); values are expected in
            [0, 1], as produced by ``ToTensor`` in the loading cell. Anything
            outside that range is clamped rather than wrapped by the uint8 cast.

    Returns:
        A PIL image built from the (height, width, channels) uint8 pixel array.
    """
    # detach()/cpu() so this also works on tensors that require grad or live on a GPU.
    pixels = img.detach().cpu().permute(1, 2, 0).clamp(0, 1)
    # round() instead of truncation: k/255 * 255 can land just below k in float32,
    # and a plain uint8 cast would then turn pixel value k into k - 1.
    pixels = (pixels * 255).round().to(torch.uint8)
    return Image.fromarray(pixels.numpy())


# Preview one training sample; each sample is an (image, label) pair, so [0] is the image.
visualize_image(img=train_dataset[1][0])

In [None]:
import matplotlib.pyplot as plt

# Show the first n_rows * n_cols training images, each titled with its class label.
n_rows, n_cols = 4, 10
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 5))
# train_dataset is already a list, so slice it directly instead of copying it first.
# zip stops at the shorter input: with fewer samples than cells, trailing axes stay empty.
for axis, (image, label) in zip(axes.flat, train_dataset[: n_rows * n_cols]):
    axis.imshow(visualize_image(image))
    axis.set_title(label)
# Hide ticks and frames on every cell, including any that were left empty.
for axis in axes.flat:
    axis.axis("off")
fig.tight_layout()

In [None]:
# Collect the training split into one batch of images (stacked along a new
# leading dimension) and a matching 1-D tensor of integer labels.
train_images = torch.stack([sample[0] for sample in train_dataset])
train_label = torch.tensor([sample[1] for sample in train_dataset])

In [None]:
# Binary sub-task: keep only the training samples from classes 0 and 1.
train_images_01 = train_images[train_label <= 1]
train_label_01 = train_label[train_label <= 1]

# Seed the global RNG so the model's random initialisation (and therefore the
# loss curve below) is reproducible across Restart & Run All.
torch.manual_seed(0)

# Flatten each image into a single feature vector once, outside the training loop.
train_features_01 = train_images_01.flatten(start_dim=1)

# Logistic regression: one linear layer producing a single logit per image.
model = torch.nn.Linear(train_features_01.shape[1], 1)
# BCEWithLogitsLoss applies the sigmoid itself, so the model outputs raw logits.
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0)

for epoch in range(100):
    # Forward pass: (N, 1) logits -> (N,) to line up with the targets.
    # squeeze(1) rather than squeeze() keeps the batch dimension when N == 1.
    out = model(train_features_01).squeeze(1)
    loss_val = loss_fn(out, train_label_01.float())

    # Compute gradient and update weights.
    optimizer.zero_grad()
    loss_val.backward()
    optimizer.step()
    print(f"{epoch=} {loss_val.item()=}")

In [None]:
# Collect the test split into one batch of images (stacked along a new leading
# dimension) and a matching 1-D tensor of integer labels.
test_images = torch.stack([image for image, _label in test_dataset])
test_label = torch.tensor([label for _image, label in test_dataset])

In [None]:
# Evaluate on the same binary sub-task: test samples from classes 0 and 1 only.
test_images_01 = test_images[test_label <= 1]
test_label_01 = test_label[test_label <= 1]

# Inference only: disable autograd so no computation graph is built.
# flatten(1) also copes with an empty selection, unlike view(-1, <fixed size>).
with torch.no_grad():
    pred_test = model(test_images_01.flatten(start_dim=1))

# A logit > 0 is equivalent to sigmoid(logit) > 0.5, i.e. predicting class 1.
test_accuracy_01 = ((pred_test[:, 0] > 0).long() == test_label_01).float().mean()
print(f"test accuracy (class 0 vs 1): {test_accuracy_01.item():.3f}")