In [None]:
import numpy as np
import torch
from PIL import Image

# Load the source photo and shrink it to a square thumbnail; the thumbnail
# is what the network is later trained to reproduce.
side = 128
im = Image.open("rose_crop.jpg")
im_small = im.resize((side, side))
im_small

In [None]:


class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal encoding of continuous coordinates.

    Every coordinate in the last dimension of the input is multiplied by a
    bank of frequencies ``exp(0), exp(1), exp(2), ...`` and passed through
    both ``sin`` and ``cos``.

    Args:
        embed_dim: Number of features produced per input coordinate. The
            frequency bank has ``ceil(embed_dim / 2)`` entries, each of which
            contributes one sine and one cosine feature.
    """

    def __init__(self, embed_dim):
        super().__init__()
        freq = torch.exp(torch.arange(0, embed_dim, 2).float() / 2)
        # Registered as a buffer so it follows ``module.to(...)`` / ``.cuda()``
        # instead of being copied to the device on every forward pass.
        # ``persistent=False`` keeps it out of ``state_dict`` so checkpoints
        # look exactly as they did when ``freq`` was a plain attribute.
        self.register_buffer("freq", freq, persistent=False)

    def forward(self, x):
        """Encode coordinates.

        Args:
            x: Tensor of shape ``(..., D)`` (e.g. ``B x 2``).

        Returns:
            Tensor of shape ``(..., 2 * F * D)`` where ``F`` is the number of
            frequencies. Features are laid out frequency-major: for each
            frequency, the ``D`` sine values followed by the ``D`` cosine
            values.
        """
        # No-op when the module already lives on x's device; still supports
        # callers that pass inputs on a different device than the module.
        freq = self.freq.to(x.device)
        # (..., 1, D) * (F, 1) -> (..., F, D)
        scaled = x[..., None, :] * freq[..., None]
        features = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
        # flatten (rather than view(..., -1)) also works for empty batches,
        # where -1 cannot be inferred.
        return features.flatten(start_dim=-2)


# f, ax = plt.subplots(3, 3, figsize=(5, 5))

# pe = PositionalEmbedding(8)
# e = pe(torch.linspace(-1, 1, 100)[:, None])
# for i in range(8):
#     ax[i // 3, i % 3].plot(e[:, i].numpy())

In [None]:


class Rose(torch.nn.Module):
    """Coordinate MLP: positional encoding followed by a ReLU network.

    Takes 2-D coordinates and produces 3 output values per coordinate.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        self.enc = PositionalEmbedding(12)

        # 24 encoded features in, 3 values out, ReLU between linear layers.
        widths = (24, 256, 128, 64, 3)
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        encoded = self.enc(x)
        return self.net(encoded)


# Fit the coordinate network to the downsampled image.

# Fixed seed so weight initialisation (and therefore the fit) is repeatable.
torch.manual_seed(0)
# Fall back to CPU when no GPU is available instead of crashing on .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Target colours, shifted from [0, 1] to [-0.5, 0.5]. Forcing RGB guarantees
# three channels, matching the network's three outputs.
rose_tensor = torch.as_tensor(np.array(im_small.convert("RGB")), dtype=torch.float32) / 255.0 - 0.5

# One (row, col) coordinate in [-1, 1]^2 per pixel; the grid size is taken
# from the image so the two cannot drift apart.
height, width = rose_tensor.shape[:2]
position = torch.stack(
    torch.meshgrid(
        torch.linspace(-1, 1, height),
        torch.linspace(-1, 1, width),
        indexing="ij",  # explicit: same as the legacy default, without the warning
    ),
    dim=-1,
)

net = Rose()

rose_tensor = rose_tensor.to(device)
position = position.to(device)
net = net.to(device)

optim = torch.optim.Adam(net.parameters(), lr=1e-3)
for it in range(5000):
    optim.zero_grad()
    # Mean absolute error over all pixels and channels.
    loss = (net(position) - rose_tensor).abs().mean()
    if it % 100 == 0:
        print(float(loss))
    loss.backward()
    optim.step()

# Render the fit; no_grad avoids building an autograd graph for inference.
with torch.no_grad():
    recon = ((net(position) + 0.5).clamp(0, 1) * 255).cpu().to(torch.uint8)
Image.fromarray(recon.numpy())

In [None]:
# Display the downsampled source image for visual comparison with the network output.
im_small

In [None]:
# Query the trained network on a denser 1024x1024 grid over the same
# [-1, 1]^2 coordinate range.
hires_device = next(net.parameters()).device  # follow wherever the network lives
position_hires = torch.stack(
    torch.meshgrid(
        torch.linspace(-1, 1, 1024),
        torch.linspace(-1, 1, 1024),
        indexing="ij",  # explicit: same as the legacy default, without the warning
    ),
    dim=-1,
).to(hires_device)

# Inference only: no_grad keeps the ~1M-point forward pass from storing
# activations for a backward pass that never happens.
with torch.no_grad():
    hires = ((net(position_hires) + 0.5).clamp(0, 1) * 255).cpu().to(torch.uint8)
Image.fromarray(hires.numpy())