[問題] 生成對抗訓練程式

作者: psw (ICK)   2023-08-29 09:54:24
如題
這個程式訓練一些照片
最後把訓練的鑑別網路權重參數結果存在TESTgen/discriminator_weights.h5中
但後來要載入TESTgen/discriminator_weights.h5這個參數鑑別網路時卻不斷說discrimi
nator_weights.h5
裡有問題
我打開discriminator_weights.h5中看起來是網路參數
跟float32浮點數格式
但要載入用來辨識其他照片時卻說無法載入HTF5格式
我用的是tensrflow GPU
求跪強者們開示
謝謝
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
#缺Keras HDF5 格式
# 設定圖像參數
img_rows = 28
img_cols = 28
channels = 1
# 設定生成器
def build_generator():
? ? noise_shape = (100,)
? ? model = Sequential()
? ? model.add(Dense(256, input_shape=noise_shape))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(512))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(1024))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh'))
? ? model.add(Reshape((img_rows, img_cols, channels)))
? ? model.summary()
? ? noise = Input(shape=noise_shape)
? ? img = model(noise)
? ? return Model(noise, img)
# 設定鑑別器
def build_discriminator():
? ? model = Sequential()
? ? model.add(Flatten(input_shape=(img_rows, img_cols, channels)))
? ? model.add(Dense((img_rows * img_cols * channels), input_shape=(img_rows, i
mg_cols, channels)))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(Dense(int((img_rows * img_cols * channels) / 2)))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(Dense(1, activation='sigmoid'))
? ? model.summary()
? ? img = Input(shape=(img_rows, img_cols, channels))
? ? validity = model(img)
? ? return Model(img, validity)
# 設定生成器和對抗器
generator = build_generator()
discriminator = build_discriminator()
# 編譯鑑別器
discriminator.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? ? ? ? optimizer=Adam(0.0002, 0.5),
? ? ? ? ? ? ? ? ? ? ? metrics=['accuracy'])
# 建立結合模型
z = Input(shape=(100,))
img = generator(z)
discriminator.trainable = False
validity = discriminator(img)
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? 漑ptimizer=Adam(0.0002, 0.5))
# 載入並預處理MNIST資料集
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
# 定義訓練參數
epochs = 3000
batch_size = 128
save_interval = 100
# 定義圖像標籤
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
# 訓練生成器和鑑別器
for epoch in range(epochs):
? ? # 訓練鑑別器
? ? idx = np.random.randint(0, X_train.shape[0], batch_size)
? ? imgs = X_train[idx]
? ? noise = np.random.normal(0, 1, (batch_size, 100))
? ? gen_imgs = generator.predict(noise)
? ? d_loss_real = discriminator.train_on_batch(imgs, valid)
? ? d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
? ? d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
? ? # 訓練生成器
? ? noise = np.random.normal(0, 1, (batch_size, 100))
? ? g_loss = combined.train_on_batch(noise, valid)
? ? # 顯示訓練進度
? ? if epoch % save_interval == 0:
? ? ? ? print(f"Epoch {epoch}/{epochs}, D loss: {d_loss[0]}, acc.: {100 * d_lo
ss[1]}, G loss: {g_loss}")
? ? ? ? # 顯示生成的圖像
? ? ? ? r, c = 2, 2
? ? ? ? noise = np.random.normal(0, 1, (r * c, 100))
? ? ? ? gen_imgs = generator.predict(noise)
? ? ? ? fig, axs = plt.subplots(r, c)
? ? ? ? cnt = 0
? ? ? ? for i in range(r):
? ? ? ? ? ? for j in range(c):
? ? ? ? ? ? ? ? axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
? ? ? ? ? ? ? ? axs[i, j].axis('off')
? ? ? ? ? ? ? ? cnt += 1
? ? ? ? plt.show()
? ?
# 將生成網路和鑑別器的參數保存到TESTgen資料夾中
os.makedirs("TESTgen", exist_ok=True)
generator.save_weights("TESTgen/generator_weights.h5")
discriminator.save_weights("TESTgen/discriminator_weights.h5", save_format="h5
")
with open("TESTgen.txt", "w") as f:
? ? f.write("Generator and discriminator parameters saved.")
print("訓練完成並保存生成網路和鑑別器參數。") ? ?
? ?
?
? ? import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
# 匯入所需的庫和模組
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
# 設定圖像參數
img_rows = 28
img_cols = 28
channels = 1
# 設定生成器
def build_generator():
? ? noise_shape = (100,)
? ? model = Sequential()
? ? model.add(Dense(256, input_shape=noise_shape))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(512))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(1024))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh'))
? ? model.add(Reshape((img_rows, img_cols, channels)))
? ? model.summary()
? ? noise = Input(shape=noise_shape)
? ? img = model(noise)
? ? return Model(noise, img)
# 建立生成器模型
def build_generator():
? ? noise_shape = (100,)
? ? model = Sequential()
? ? model.add(Dense(256, input_shape=noise_shape)) ?# 全連接層,輸入是噪音
? ? model.add(LeakyReLU(alpha=0.2)) ?# LeakyReLU 激活函數
? ? model.add(BatchNormalization(momentum=0.8)) ?# BatchNormalization 正規化
? ? model.add(Dense(512))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(1024))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(BatchNormalization(momentum=0.8))
? ? model.add(Dense(img_rows * img_cols * channels, activation='tanh')) ?# 生
成器輸出,使用 tanh 激活函數
? ? model.add(Reshape((img_rows, img_cols, channels))) ?# 重塑輸出形狀
? ? model.summary()
? ? noise = Input(shape=noise_shape) ?# 噪音輸入
? ? img = model(noise) ?# 使用模型生成圖像
? ? return Model(noise, img) ?# 返回噪音和生成圖像模型
# 設定鑑別器
def build_discriminator():
? ? model = Sequential()
? ? model.add(Flatten(input_shape=(img_rows, img_cols, channels))) ?# 將圖像展
平為一維
? ? model.add(Dense((img_rows * img_cols * channels), input_shape=(img_rows, i
mg_cols, channels))) ?# 全連接層
? ? model.add(LeakyReLU(alpha=0.2)) ?# LeakyReLU 激活函數
? ? model.add(Dense(int((img_rows * img_cols * channels) / 2)))
? ? model.add(LeakyReLU(alpha=0.2))
? ? model.add(Dense(1, activation='sigmoid')) ?# 預測真假的輸出,使用 sigmoid
激活函數
? ? model.summary()
? ? img = Input(shape=(img_rows, img_cols, channels)) ?# 圖像輸入
? ? validity = model(img) ?# 使用模型判斷真假
? ? return Model(img, validity) ?# 返回圖像和判斷真假模型
# 建立生成器和鑑別器
generator = build_generator() ?# 創建生成器模型
discriminator = build_discriminator() ?# 創建鑑別器模型
# 編譯鑑別器
discriminator.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? ? ? ? optimizer=Adam(0.0002, 0.5),
? ? ? ? ? ? ? ? ? ? ? metrics=['accuracy'])
# 建立結合模型
z = Input(shape=(100,))
img = generator(z)
discriminator.trainable = False ?# 在結合模型中,鑑別器權重凍結
validity = discriminator(img)
combined = Model(z, validity) ?# 創建結合模型,輸入噪音,輸出真假
combined.compile(loss='binary_crossentropy',
? ? ? ? ? ? ? ? 漑ptimizer=Adam(0.0002, 0.5))
# 載入並預處理MNIST資料集
(X_train, _), (_, _) = mnist.load_data() ?# 載入MNIST數據集
X_train = (X_train.astype(np.float32) - 127.5) / 127.5 ?# 正規化數據到-1到1之

X_train = np.expand_dims(X_train, axis=3) ?# 增加一個維度(通道)
# 定義訓練參數
epochs = 3000 ?# 訓練迭代次數
batch_size = 128 ?# 批次大小
save_interval = 100 ?# 每隔多少個迭代保存模型
# 定義圖像標籤
valid = np.ones((batch_size, 1)) ?# 真實標籤
fake = np.zeros((batch_size, 1)) ?# 假標籤
# 訓練生成器和鑑別器
for epoch in range(epochs):
? ? # 訓練鑑別器
? ? idx = np.random.randint(0, X_train.shape[0], batch_size)
? ? imgs = X_train[idx] ?# 隨機選取真實圖像
? ? noise = np.random.normal(0, 1, (batch_size, 100)) ?#
?
(X_tr
X_tra

X_tra
作者: lycantrope (阿寬)   2023-08-29 10:52:00
問GPT,不經大腦複製貼上,也沒寫你是怎麼載入h5
作者: tsoahans (ㄎㄎ)   2023-08-29 15:04:00
save_weights對應load_weights model.save對load_model
作者: lycantrope (阿寬)   2023-08-29 15:18:00
同樓上,從model= build_discriminator()產生model後model.load_weights才對
作者: chang1248w (彩棠)   2023-09-14 22:58:00
超熱心ww

Links booklink

Contact Us: admin [ a t ] ucptt.com