import torch import torchvision import torchvision.transforms as transforms import matplotlib.pyplot as plt # ==== MNIST データセットから1枚取得 ==== transform = transforms.Compose([ transforms.ToTensor(), ]) dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform) image, _ = dataset[0] # 1枚だけ image = image.squeeze(0) # (1,28,28) → (28,28) # ==== 拡散プロセス ==== T = 10 # ステップ数 noisy_images = [] x = image.clone() for t in range(T): noise = torch.randn_like(x) * 0.1 x = (x + noise).clamp(0,1) # ノイズを足す noisy_images.append(x) # ===== 擬似的な逆拡散(平均を取ってノイズを少しずつ減らすだけ) ===== denoised_images = [] y = noisy_images[-1].clone() for t in range(T): y = (y*0.9 + image*0.1) # 単純な補間で元画像に近づける denoised_images.append(y) # ===== 可視化 ===== fig, axes = plt.subplots(3, T, figsize=(15, 5)) # 元画像 axes[0,0].imshow(image, cmap="gray") axes[0,0].set_title("Original") axes[0,0].axis("off") # 拡散 (ノイズ付与) for i in range(T): axes[1,i].imshow(noisy_images[i], cmap="gray") axes[1,i].axis("off") axes[1,0].set_title("Forward Diffusion") # 逆拡散 (ノイズ除去の雰囲気だけ) for i in range(T): axes[2,i].imshow(denoised_images[i], cmap="gray") axes[2,i].axis("off") axes[2,0].set_title("Reverse (Toy)") plt.show()

