Comments (2)
Can anyone check this? I've changed gan_train.py to the following. The code runs, but the losses don't change. What did I miss?
...
with tf.name_scope('optimizer'):
    gan_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    dis_vars = [var for var in gan_vars if 'dis' in var.name]
    gen_vars = [var for var in gan_vars if 'gen' in var.name]
    # dis_step = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE, name='dis_rmsprop').minimize(dis_loss, var_list=dis_vars)
    # gen_step = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE, name='gen_rmsprop').minimize(gen_loss, var_list=gen_vars)
with tf.name_scope('dis_clip'):
    dis_clip = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in dis_vars]
sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph('./model/gan-1412.meta')
saver.restore(sess, tf.train.latest_checkpoint('./model'))
graph = tf.get_default_graph()
dis_step = graph.get_operation_by_name('dis/conv1/w/dis_rmsprop')
gen_step = graph.get_operation_by_name('gen/fc1/w/gen_rmsprop')
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(TRAIN_ITERS):
    ...
Sorry! I'm not familiar with all of this, so I got confused and made a mistake. I think this is the simple, correct way:
with tf.name_scope('optimizer'):
    gan_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    dis_vars = [var for var in gan_vars if 'dis' in var.name]
    gen_vars = [var for var in gan_vars if 'gen' in var.name]
    dis_step = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE, name='dis_rmsprop').minimize(dis_loss, var_list=dis_vars)
    gen_step = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE, name='gen_rmsprop').minimize(gen_loss, var_list=gen_vars)
with tf.name_scope('dis_clip'):
    dis_clip = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in dis_vars]
sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
saver.restore(sess, './model/gan-1412')
graph = tf.get_default_graph()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(TRAIN_ITERS):
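For anyone unsure what the `dis_clip` line in the snippet above does: `tf.clip_by_value(var, -0.01, 0.01)` limits every element of a discriminator variable to the range [-0.01, 0.01], and `var.assign(...)` writes the clipped values back. Below is a minimal plain-Python stand-in for that element-wise clipping, not TensorFlow code. The function name `clip_weights` and the sample weights are made up for illustration.

```python
def clip_weights(weights, low=-0.01, high=0.01):
    """Return a copy of `weights` with every value limited to [low, high]."""
    return [min(max(w, low), high) for w in weights]


# Hypothetical discriminator weights, invented for this illustration.
weights = [-0.5, -0.01, 0.003, 0.2]
clipped = clip_weights(weights)
print(clipped)  # [-0.01, -0.01, 0.003, 0.01]
```

In the TensorFlow version, building the assign ops does not clip anything by itself. They only take effect when they are run, for example with `sess.run(dis_clip)`. Where that happens inside the training loop is not shown in this thread.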