[問題] TensorFlow網路參數儲存問題

作者: Paudse (SICO)   2018-04-23 19:44:55
我想將NN的參數儲存下次繼續學習
但發現儲存時似乎發生問題
叫出的參數每次都一樣
我的程式結構如下
還請強者指教
謝謝
class DQ:
def __init__():
self.sess = tf.Session()
saver = tf.train.Saver()
self.sess.run(tf.global_variables_initializer())
with tf.Session() as sess:
if os.path.isfile("save_net.ckpt.index"):
saver.restore(sess, "save_net.ckpt")
print('File exists, loading previous data!')
else:
# save_path = self.saver.save(self.sess, "save_net.ckpt")
print('File does not exist, starting fresh')
def _build_net(self):
省略
def learn(self,save_step):
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
save_path = saver.save(sess, "save_net.ckpt")
print('save parameters')
作者: jameszhan (123)   2018-04-24 08:25:00
用with的話 你要跟訓練在同一個with裡吧不然你存參數時候的session應該沒東西最後的with那行等於一個新的session 你初始化參數後就直接存 中間應該要有訓練過程還是你只是想問為何初始化後存的參數會一樣?
作者: Paudse (SICO)   2018-04-24 10:38:00
恩恩 對阿 初始化參數都會一樣 是為什麼呢 謝謝甚至我把之前處存的ckpt檔都刪了 跑出來的參數還是一樣我朋友後來說他會存很多個ckpt 可以設定幾個epoch存一次要restore最後一個ckpt才是最接近訓練最後的結果
作者: jameszhan (123)   2018-04-24 13:13:00
當然啊 你可以看一下tensorflow的文件saver(max=n) 可以設定要保留幾個檔案
作者: Paudse (SICO)   2018-04-24 13:31:00
我現在用model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)但還是都從最一開始的開始訓練 不知道是怎麼回事另外也已經改成saver=tf.train.Saver(max_to_keep=1)
作者: goldflower (金色小黃花)   2018-04-24 13:34:00
=1代表只存一個吧https://stackoverflow.com/questions/48324072/照這個做應該就好了
作者: jameszhan (123)   2018-04-24 13:48:00
參數初始化的部分可以看這個truncated_normal_initiali從最一開始的訓練或許是你本來就只有一開始才有存?直接去github看別人完整的code比較快 看人家怎麼用的
作者: chchan1111 (123)   2018-04-24 13:54:00
對了 你這個code是不是怪怪的 你一開始就有實體化session了 為何後面還要with tf.se....
作者: Paudse (SICO)   2018-04-24 13:54:00
感謝各位的建議 我後來發現 我原本把放在saver.restore
作者: chchan1111 (123)   2018-04-24 13:55:00
直接self.sess.run就可以了 不然你等於又實體化一個session
作者: Paudse (SICO)   2018-04-24 13:55:00
一個if判斷句裡面檢查有沒有之前存的ckpt 但一值失敗我後來把saver.restore拿出那個if結構外就可以了雖然不太懂為何會有這個問題 不過現在OK了 感謝大家!!大大們說的沒錯 我後來把with拿掉了

Links booklink

Contact Us: admin [ a t ] ucptt.com