WHCSRL 技术网

tensorflow2.0使用自带的函数求精准率和召回率(解决Shapes (None, 10) and (None, 1) are incompatible)

本代码使用的是cifar10数据集,所以有十个类别
废话不多说,直接给代码吧

import tensorflow as tf
from tensorflow.keras import datasets, Sequential, layers,metrics
(x_train, y_train), _ = datasets.cifar10.load_data()

def procession(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.squeeze(y)
    y = tf.one_hot(y, depth=10)

    return x, y
model = Sequential([
    layers.Flatten(input_shape=(32, 32, 3)),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(32, activation='relu'),
    layers.Dense(10, activation='softmax')
])
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1000).map(procession).batch(128)
# model.compile(loss=tf.losses.binary_crossentropy, optimizer='adam', metrics=['accuracy'])
model.compile(loss=tf.losses.binary_crossentropy, optimizer='adam', metrics=[metrics.Recall()])
# model.compile(loss=tf.losses.binary_crossentropy, optimizer='adam', metrics=[metrics.Precision()])
model.fit(train_db, epochs=5)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

出现Shapes (None, 10) and (None, 1) are incompatible的原因是:
x通过模型之后会得到一个shape为(None,10)的数据, 而y因为没有进行one_hot编码,y.shape=(None, 1),形状不同所以不能进行计算
没有对标签y_train进行one_hot编码,但是单单进行one_hot编码也是不够的,因为进行one_hot编码之后y_train.shape = (None, 1, 10)就会报错Shapes (None, 10) and (None, 1,10) are incompatible, 所以在对y_train进行处理时,通过tf.squeeze(y_train)是的y_train.shape = (None, 10),这样子就可以进行计算了。

文章知识点与官方知识档案匹配,可进一步学习相关知识
推荐阅读