import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# ==== Load a single sample from the MNIST training set ====
# ToTensor turns the dataset's image into a torch tensor with a leading
# channel axis, which is removed below.
transform = transforms.Compose([transforms.ToTensor()])
dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
# Only the first (image, label) pair is used; the label is discarded.
image, _ = dataset[0]
# Drop the channel axis: (1, 28, 28) -> (28, 28) so imshow can draw it.
image = image.squeeze(0)
# ==== Forward diffusion process ====
# Starting from the clean image, repeatedly add Gaussian noise and keep a
# snapshot after every step so the progression can be plotted later.
T = 10  # number of diffusion steps
NOISE_STD = 0.1  # standard deviation of the noise added at each step
noisy_images = []
x = image.clone()  # work on a copy so the original image stays untouched
for _ in range(T):
    noise = torch.randn_like(x) * NOISE_STD
    # Clamp so pixel values stay in the displayable [0, 1] range.
    x = (x + noise).clamp(0, 1)
    # `x` is rebound to a new tensor each step, so storing it is safe.
    noisy_images.append(x)
# ===== Toy reverse diffusion =====
# This is NOT a learned denoiser: each step simply blends the current image
# with the known original (90% current, 10% original), so the result
# drifts back toward the clean image and the noise fades gradually.
denoised_images = []
y = noisy_images[-1].clone()  # start from the noisiest forward-step image
for _ in range(T):
    y = y * 0.9 + image * 0.1
    # The blend creates a new tensor every step, so each snapshot is distinct.
    denoised_images.append(y)
# ===== Visualization =====
# Grid layout: row 0 = original image, row 1 = forward (noising) steps,
# row 2 = toy reverse steps. squeeze=False keeps `axes` two-dimensional
# even when T == 1, so the axes[row, col] indexing below always works.
fig, axes = plt.subplots(3, T, figsize=(15, 5), squeeze=False)

# Row 0: the original image goes in the first column only.
axes[0, 0].imshow(image, cmap="gray")
axes[0, 0].set_title("Original")
# Hide the frames and ticks of every panel in this row, including the
# unused ones, which would otherwise show up as empty boxes.
for ax in axes[0]:
    ax.axis("off")

# Row 1: forward diffusion (noise being added step by step).
for ax, noisy in zip(axes[1], noisy_images):
    ax.imshow(noisy, cmap="gray")
    ax.axis("off")
axes[1, 0].set_title("Forward Diffusion")

# Row 2: toy reverse process (only illustrates the idea of denoising).
for ax, denoised in zip(axes[2], denoised_images):
    ax.imshow(denoised, cmap="gray")
    ax.axis("off")
axes[2, 0].set_title("Reverse (Toy)")

plt.show()