import torch
import torch.nn as nn
import torch.optim as optim


# ==== Generator ====
class Generator(nn.Module):
    """DCGAN generator for 28x28 single-channel (MNIST) images.

    Upsamples a latent tensor of shape (N, nz, 1, 1) through transposed
    convolutions: 1x1 -> 7x7 -> 14x14 -> 28x28. Output is in [-1, 1]
    (Tanh), matching inputs normalized with mean/std 0.5.
    """

    def __init__(self, nz=100, ngf=64, nc=1):
        # nz: latent dimension, ngf: base feature-map count, nc: image channels
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 7, 1, 0, bias=False),      # 1x1 -> 7x7
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),  # 7x7 -> 14x14
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),       # 14x14 -> 28x28
            nn.Tanh(),
        )

    def forward(self, x):
        return self.main(x)


# ==== Discriminator ====
class Discriminator(nn.Module):
    """DCGAN discriminator.

    Maps a (N, nc, 28, 28) image to a realness probability of shape
    (N, 1) via strided convolutions followed by a Sigmoid.
    """

    def __init__(self, nc=1, ndf=64):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),       # 28x28 -> 14x14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),  # 14x14 -> 7x7
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, 1, 7, 1, 0, bias=False),    # 7x7 -> 1x1
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.main(x).view(-1, 1)


def main():
    """Train a DCGAN on MNIST for one demo epoch, then display samples.

    Previously all of this ran at import time; wrapping it in main() with a
    script guard keeps `python3 dcgan.py` behaving identically while making
    the model classes importable without triggering a download/training run.
    """
    # Heavy, environment-dependent imports are deferred so the module can be
    # imported (for Generator/Discriminator) without requiring them.
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt

    # ==== Dataset ====
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # scale to [-1, 1] to match Tanh output
    ])
    dataset = torchvision.datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    # ==== Model initialization ====
    device = "cuda" if torch.cuda.is_available() else "cpu"
    netG = Generator().to(device)
    netD = Discriminator().to(device)
    criterion = nn.BCELoss()
    optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # ==== Training ====
    nz = 100  # latent dimension (the original assigned this twice)
    for epoch in range(1):  # demo: one epoch only
        for i, (real, _) in enumerate(dataloader):
            real = real.to(device)
            b_size = real.size(0)
            label_real = torch.ones(b_size, 1, device=device)
            label_fake = torch.zeros(b_size, 1, device=device)

            # --- Discriminator step: real batch toward 1, fake batch toward 0 ---
            netD.zero_grad()
            output_real = netD(real)
            lossD_real = criterion(output_real, label_real)
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise)
            # detach() so this backward pass does not propagate into the generator
            output_fake = netD(fake.detach())
            lossD_fake = criterion(output_fake, label_fake)
            lossD = lossD_real + lossD_fake
            lossD.backward()
            optimizerD.step()

            # --- Generator step: push D(G(z)) toward the "real" label ---
            netG.zero_grad()
            output = netD(fake)
            lossG = criterion(output, label_real)
            lossG.backward()
            optimizerG.step()

            if i % 200 == 0:
                print(f"Epoch[{epoch}] Step[{i}] Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}")

    # ==== Display generated samples ====
    noise = torch.randn(64, nz, 1, 1, device=device)
    fake = netG(noise).detach().cpu()
    grid = torchvision.utils.make_grid(fake, padding=2, normalize=True)
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()


if __name__ == "__main__":
    main()
$ python3 dcgan.py
Epoch[0] Step[0] Loss_D: 1.2740 Loss_G: 0.8563
Epoch[0] Step[200] Loss_D: 0.9803 Loss_G: 1.6264
Epoch[0] Step[400] Loss_D: 0.6745 Loss_G: 1.1438
Epoch[0] Step[600] Loss_D: 0.5692 Loss_G: 1.5479
Epoch[0] Step[800] Loss_D: 0.5681 Loss_G: 1.5251