import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# ==== Generator ====
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector to a single-channel 28x28 image.

    Spatial upsampling path: 1x1 -> 7x7 -> 14x14 -> 28x28.
    Output is squashed to [-1, 1] by the final Tanh, matching inputs
    normalized with mean 0.5 / std 0.5.
    """

    def __init__(self, nz=100, ngf=64, nc=1):
        # nz: latent-vector length, ngf: base feature-map width, nc: image channels
        super().__init__()
        layers = [
            # (nz, 1, 1) -> (ngf*4, 7, 7)
            nn.ConvTranspose2d(nz, ngf * 4, 7, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # (ngf*4, 7, 7) -> (ngf*2, 14, 14)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # (ngf*2, 14, 14) -> (nc, 28, 28)
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from latent codes z of shape (N, nz, 1, 1)."""
        return self.main(z)
# ==== Discriminator ====
class Discriminator(nn.Module):
    """DCGAN discriminator: scores a 28x28 image as real (1) or fake (0).

    Spatial downsampling path: 28x28 -> 14x14 -> 7x7 -> 1x1; the final
    Sigmoid yields a probability in (0, 1), returned with shape (N, 1).
    """

    def __init__(self, nc=1, ndf=64):
        # nc: image channels, ndf: base feature-map width
        super().__init__()
        layers = [
            # (nc, 28, 28) -> (ndf, 14, 14); no BatchNorm on the input layer
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf, 14, 14) -> (ndf*2, 7, 7)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*2, 7, 7) -> (1, 1, 1)
            nn.Conv2d(ndf * 2, 1, 7, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, image):
        """Return per-image real/fake probabilities with shape (N, 1)."""
        score = self.main(image)
        return score.view(-1, 1)
# ==== Dataset preparation ====
# Normalize((0.5,), (0.5,)) scales pixels to [-1, 1], matching the
# generator's Tanh output range.
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# download=True fetches MNIST into ./data on first run (network/disk side effect).
dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
# ==== Model initialization ====
device = "cuda" if torch.cuda.is_available() else "cpu"
netG = Generator().to(device)
netD = Discriminator().to(device)
# BCE over sigmoid outputs; lr=2e-4 with betas=(0.5, 0.999) is the standard
# DCGAN optimizer setting.
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
# ==== Training ====
nz = 100  # latent dimension; must match Generator's default
# (fix: the original assigned nz = 100 twice in a row — dead duplicate removed)

for epoch in range(1):  # demo: a single epoch only
    for step, (real, _) in enumerate(dataloader):
        real = real.to(device)
        b_size = real.size(0)
        label_real = torch.ones(b_size, 1, device=device)
        label_fake = torch.zeros(b_size, 1, device=device)

        # --- Train the discriminator: push real -> 1, fake -> 0 ---
        netD.zero_grad()
        lossD_real = criterion(netD(real), label_real)
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        # detach() so D's update does not backpropagate into the generator
        lossD_fake = criterion(netD(fake.detach()), label_fake)
        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimizerD.step()

        # --- Train the generator: make D classify fakes as real ---
        netG.zero_grad()
        lossG = criterion(netD(fake), label_real)
        lossG.backward()
        optimizerG.step()

        if step % 200 == 0:
            print(f"Epoch[{epoch}] Step[{step}] Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}")
# ==== Display generated images ====
noise = torch.randn(64, nz, 1, 1, device=device)
# Inference only: no_grad() skips autograd graph construction entirely
# (the original built the graph and then .detach()ed — same values, wasted memory).
with torch.no_grad():
    fake = netG(noise).cpu()
# NOTE(review): netG is still in train() mode here, so BatchNorm uses batch
# statistics; calling netG.eval() first may be intended — confirm.
grid = torchvision.utils.make_grid(fake, padding=2, normalize=True)
plt.imshow(grid.permute(1, 2, 0))
plt.show()
# Example run (terminal transcript — kept as a comment so the file stays valid Python):
#   $ python3 dcgan.py
#   Epoch[0] Step[0] Loss_D: 1.2740 Loss_G: 0.8563
#   Epoch[0] Step[200] Loss_D: 0.9803 Loss_G: 1.6264
#   Epoch[0] Step[400] Loss_D: 0.6745 Loss_G: 1.1438
#   Epoch[0] Step[600] Loss_D: 0.5692 Loss_G: 1.5479
#   Epoch[0] Step[800] Loss_D: 0.5681 Loss_G: 1.5251