拡散モデル(diffusion)

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# ==== MNIST データセットから1枚取得 ====
transform = transforms.Compose([
    transforms.ToTensor(),
])
dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
image, _ = dataset[0]  # 1枚だけ
image = image.squeeze(0)  # (1,28,28) → (28,28)

# ==== 拡散プロセス ====
T = 10  # ステップ数
noisy_images = []
x = image.clone()
for t in range(T):
    noise = torch.randn_like(x) * 0.1
    x = (x + noise).clamp(0,1)  # ノイズを足す
    noisy_images.append(x)

# ===== 擬似的な逆拡散(平均を取ってノイズを少しずつ減らすだけ) =====
denoised_images = []
y = noisy_images[-1].clone()
for t in range(T):
    y = (y*0.9 + image*0.1)  # 単純な補間で元画像に近づける
    denoised_images.append(y)

# ===== 可視化 =====
fig, axes = plt.subplots(3, T, figsize=(15, 5))

# 元画像
axes[0,0].imshow(image, cmap="gray")
axes[0,0].set_title("Original")
axes[0,0].axis("off")

# 拡散 (ノイズ付与)
for i in range(T):
    axes[1,i].imshow(noisy_images[i], cmap="gray")
    axes[1,i].axis("off")
axes[1,0].set_title("Forward Diffusion")

# 逆拡散 (ノイズ除去の雰囲気だけ)
for i in range(T):
    axes[2,i].imshow(denoised_images[i], cmap="gray")
    axes[2,i].axis("off")
axes[2,0].set_title("Reverse (Toy)")

plt.show()