首页 文章

tensorflow tensorboard错误:您必须为占位符张量提供一个值

提问于
浏览
1

初学者并尝试在tensorflow prgm中使用tensorboard . 我在教程中看到我添加了tensorboard refs,但是我收到以下错误消息:

InvalidArgumentError(请参见上面的回溯):您必须使用dtype float [[Node:x = Placeholderdtype = DT_FLOAT,shape = [],_ device =“/ job:localhost / replica:0 /)为占位符张量'x'提供值任务:0 / CPU:0" ]]

错误似乎与我在训练循环中添加的这一行有关 . 没有这一行,程序不会抛出任何错误:

summary = sess.run(merged_summary_op, {x: x_train, y_prim: y_train})

谢谢,如果有人可以检查下面的代码并帮助:

# -*- coding: utf-8 -*-

import tensorflow as tf
sess = tf.Session()

# parms
a = tf.Variable([2.0], dtype=tf.float32, name="a")
x = tf.placeholder(tf.float32, name="x")
b = tf.Variable([1.0], dtype=tf.float32, name="b")

# model : y=ax+b
with tf.name_scope('Model'):
    y = tf.add ((tf.multiply(a, x)), b)

# info for TensorBoard 
writer = tf.summary.FileWriter("D:\\tmp\\tensorflow\\logs", sess.graph)

# loss fct - mean square error
with tf.name_scope('cost'):
    y_prim = tf.placeholder(tf.float32)
    cost = tf.reduce_sum(tf.square(y - y_prim))

# optimizer = gradientdescent
with tf.name_scope('GradDes'):
    optimizer = tf.train.GradientDescentOptimizer(0.01)
    train = optimizer.minimize(cost)

# train datas 
x_train = [1, 2, 3, 4]
y_train = [5.2, 8.4, 11.1, 14.7]

# summary for Tensorboard
tf.summary.scalar("cost", cost)
merged_summary_op = tf.summary.merge_all()

# init vars
init = tf.global_variables_initializer()

# train loop
sess.run(init)  
for i in range (500):
    sess.run([train, cost], feed_dict={x: x_train, y_prim: y_train})
    summary = sess.run(merged_summary_op, {x: x_train, y_prim: y_train})
    a_found, b_found, curr_cost = sess.run([a, b, cost], feed_dict={x:x_train, y_prim: y_train}) 
    print("iteration :", i, "a: ", a_found, "b: ", b_found, "cost: ",curr_cost)

2 回答

  • 0

    不要在单独的 sess.run 中执行合并的摘要操作 . 试试这个:

    a_found, b_found, curr_cost, summary = sess.run([a, b, cost, merged_summary_op], feed_dict={x:x_train, y_prim: y_train})
    

    会话运行后,您需要调用FileWriter的 add_summary 方法:

    writer.add_summary(summary)
    writer.flush()
    
  • 1

    是的,我的代码存在同样的问题,我通过重置图表解决了这个问题,将其添加到图表定义的开头

    tf.reset_default_graph()

    我认为你正在做的一切正确

相关问题