"""Train a fully connected variational autoencoder (VAE) on MNIST and sample from it.

Run as a script (``python3 vae.py``): downloads MNIST into ``./data`` if needed,
trains for a few epochs while printing the per-sample average loss, then decodes
random latent vectors into images. Importing this module only defines the model,
loss and helpers; it does not download data or train.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader


def build_train_loader(root='./data', batch_size=128):
    """Return a shuffled DataLoader over the MNIST training split.

    Args:
        root: directory MNIST is downloaded to / read from.
        batch_size: number of images per batch.

    ``ToTensor`` is relied on to produce float pixels in [0, 1], which the
    binary-cross-entropy reconstruction term in ``loss_function`` requires.
    """
    # Imported lazily so the model and loss can be imported and reused
    # (e.g. in tests) without torchvision installed and without a download.
    from torchvision import datasets, transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_dataset = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


# ==== VAE model ====
class VAE(nn.Module):
    """Fully connected VAE: x -> hidden -> (mu, logvar) -> z -> hidden -> reconstruction.

    Args:
        input_dim: length of each flattened input vector (784 = 28 * 28 for MNIST).
        hidden_dim: width of the single hidden layer in encoder and decoder.
        latent_dim: dimensionality of the latent code z.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        # Kept on the instance so flattening and sampling code can use the
        # model's actual sizes instead of repeating magic numbers.
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map flattened inputs ``(batch, input_dim)`` to ``(mu, logvar)``, each ``(batch, latent_dim)``."""
        h = torch.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I) and std = exp(0.5 * logvar).

        Sampling through this reparameterization keeps z differentiable
        with respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``(batch, latent_dim)`` to reconstructions in (0, 1) of shape ``(batch, input_dim)``."""
        h = torch.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))

    def forward(self, x):
        """Encode, sample z, and decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


def loss_function(recon_x, x, mu, logvar):
    """Return the negative ELBO summed over the batch: BCE reconstruction + KL divergence.

    Args:
        recon_x: decoder output with values in [0, 1], same shape as ``x``.
        x: reconstruction target with values in [0, 1].
        mu: encoder means.
        logvar: encoder log-variances.

    Returns:
        Scalar tensor. It is a sum, not a mean, so divide by the number of
        samples to get a per-sample value.
    """
    bce = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL( N(mu, exp(logvar)) || N(0, I) ), summed over latent dims and batch.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld


def train_one_epoch(model, loader, optimizer):
    """Run one optimization pass over ``loader``; return the average loss per sample.

    The loader is expected to yield ``(images, labels)`` pairs; labels are ignored.
    """
    model.train()
    total_loss = 0.0
    for data, _ in loader:
        data = data.view(-1, model.input_dim)  # flatten images to vectors
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
    return total_loss / len(loader.dataset)


def sample(model, num_samples=16, image_shape=(1, 28, 28)):
    """Decode ``num_samples`` latent vectors drawn from N(0, I) into images.

    The latent size is taken from ``model.latent_dim`` so this works for any
    ``VAE(latent_dim=...)``. Returns a tensor of shape ``(num_samples, *image_shape)``.
    """
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, model.latent_dim)  # random latent vectors
        return model.decode(z).view(-1, *image_shape)  # decode to images


def main(epochs=5):
    """Train a VAE on MNIST for ``epochs`` epochs, then return ``(model, samples)``."""
    train_loader = build_train_loader()
    model = VAE()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(epochs):
        avg_loss = train_one_epoch(model, train_loader, optimizer)
        print(f'Epoch {epoch + 1}, Loss: {avg_loss:.4f}')
    samples = sample(model)
    return model, samples


if __name__ == "__main__":
    main()
$ python3 vae.py
Epoch 1, Loss: 164.1793
Epoch 2, Loss: 121.6226
Epoch 3, Loss: 114.7562
Epoch 4, Loss: 111.7122
Epoch 5, Loss: 109.9405
訓練 Loss がエポックごとに下がっている → 学習（最適化）が進んでいる証拠。ただしこれは訓練データ上の値なので、汎化性能を確かめるにはテストデータでの Loss も確認する必要がある。
VAE の Loss は「再構成誤差 (BCE) + KL ダイバージェンス」（表示値は画像 1 枚あたりの平均）なので、分類モデルの Accuracy のような「大きいほど良い」指標とは違い、「小さいほど良い」指標として見る。