WHCSRL 技术网

单神经元预测猫

写这个的目的是熟悉前向传播,反向传播的过程,不调用框架实现。

这次来实现一个检测猫的模型
在这里插入图片描述

构建如下神经元网络:
在这里插入图片描述
这是神经网络最简单的模型,只有一个神经元。


数据的准备

数据以h5格式保存
建议下载HDFView,可视化h5文件
在这里插入图片描述

print ("train_set_x_orig shape: " + str(train_set_x_orig.shape))
print ("train_set_y shape: " + str(train_set_y.shape))
print ("test_set_x_orig shape: " + str(test_set_x_orig.shape))
print ("test_set_y shape: " + str(test_set_y.shape))

train_set_x_orig shape: (209, 64, 64, 3)
train_set_y shape: (1, 209)
test_set_x_orig shape: (50, 64, 64, 3)
test_set_y shape: (1, 50)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

训练集209个,测试集50个,要注意不要让程序学习了测试集


激活函数:

采用sigmoid
在这里插入图片描述

def sigmoid(z):
    s = 1 / (1 + np.exp(-z))    
    return s
  • 1
  • 2
  • 3

前向传播

(网上扒的图,我懒得画了)
在这里插入图片描述
在这里插入图片描述
如果是一个一个算,一层就要O(N)复杂度。
使用向量化来计算:
在这里插入图片描述
向量化后时间复杂度变成O(1)
在这里插入图片描述
实际上一行代码就可以了:
在这里插入图片描述
用for循环实现:
在这里插入图片描述
用向量化实现:
在这里插入图片描述

 m = X.shape[1]
    
    # 前向传播
    A = sigmoid(np.dot(w.T, X) + b)                             
    cost = -np.sum(Y*np.log(A) + (1-Y)*np.log(1-A)) / m  
  • 1
  • 2
  • 3
  • 4
  • 5

损失函数

在这里插入图片描述

cost = -np.sum(Y*np.log(A) + (1-Y)*np.log(1-A)) / m  
  • 1

反向传播

先回顾一下前面的激活函数跟前向传播我们做了什么:

使用传播函数可以用图片中的每一个像素与权重w和阈值b做运算,计算结果a表明该图片是否是猫(大于域值判定为猫)

使用激活函数sigmoid可以将a的值映射到0~1的区间,计算结果y表明该图片是否是猫

可以总结为以下:
  1)传播函数通过w、b计算出A,通过激活函数又计算出Y

2)反向传播函数通过Y计算出dw和db

3)使用dw和db,通过一种叫“梯度下降”的方法得到新的w和b

4)使用更新后的w和b重复前面的运算过程

需要注意的是,“梯度下降”是一种更新w和b的方法,不同模型可能会使用不同的方法来更新w和b,只是我们这里使用这个比较容易理解的方法而已。其次我们不讨论“反向传播”的原理,只是对实现过程做分析。

激活函数中我们讲过,使用激活函数的目的是将A的值转为0~1的区间值Y,方便与标签进行对比。对比的方法很直接,使用Y减去标签值即可:
dZ = Y - train_label
其中Y是激活函数的输出,代表着图片是否接近猫。train_label就是所有训练图片的标签值(只有两个值一个1代表是猫,一个0代表不是猫)。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
通过上面的实例演示,我们会发现“若图片真的是猫,则Y减去train_label的值会变为负数(因为train_label值为1),若图片不是猫,则Y减去train_label的值是不变的(因为train_label的值为0)”。

这部分我还没理解透

下面依据dZ的值我们来计算出dw和db。

dw = np.dot(train_img, dZ.T)/m   db = np.sum(dZ)/m

更新参数

前面我们通过传播函数和激活函数得到图片为猫的概率,又通过反向传播函数计算出了dw和db。有了dw和db,我们就可以更新权重w和偏置b。

w = w - learning_rate*dw

b = b - learning_rate*db

def propagate(w, b, X, Y):
    """
    参数:
    w -- 权重数组,维度是(12288, 1)
    b -- 偏置bias
    X -- 图片的特征数据,维度是 (12288, 209)
    Y -- 图片对应的标签,0或1,0是无猫,1是有猫,维度是(1,209)

    返回值:
    cost -- 成本
    dw -- w的梯度
    db -- b的梯度
    """
    
    m = X.shape[1]
    
    # 前向传播
    #(1,12288)*(12288,209)=(1,209)
    A = sigmoid(np.dot(w.T, X) + b)                             
    cost = -np.sum(Y*np.log(A) + (1-Y)*np.log(1-A)) / m  
    
    # 反向传播
    dZ = A - Y
    #(12288,209)*(209,1)=(12288,1)
    dw = np.dot(X,dZ.T) / m
    db = np.sum(dZ) / m
    
    # 将dw和db保存到字典里面
    grads = {"dw": dw,
             "db": db}
    
    return grads, cost
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

实验:
在这里插入图片描述
进行5000次学习,学习率为0.005
结果如下:

优化100次后成本是: 0.693147
优化200次后成本是: 0.584508
优化300次后成本是: 0.466949
优化400次后成本是: 0.376007
优化500次后成本是: 0.331463
优化600次后成本是: 0.303273
优化700次后成本是: 0.279880
优化800次后成本是: 0.260042
优化900次后成本是: 0.242941
优化1000次后成本是: 0.228004
优化1100次后成本是: 0.214820
优化1200次后成本是: 0.203078
优化1300次后成本是: 0.192544
优化1400次后成本是: 0.183033
优化1500次后成本是: 0.174399
优化1600次后成本是: 0.166521
优化1700次后成本是: 0.159305
优化1800次后成本是: 0.152667
优化1900次后成本是: 0.146542
优化2000次后成本是: 0.140872
优化2100次后成本是: 0.135608
优化2200次后成本是: 0.130708
优化2300次后成本是: 0.126137
优化2400次后成本是: 0.121861
优化2500次后成本是: 0.117855
优化2600次后成本是: 0.114093
优化2700次后成本是: 0.110554
优化2800次后成本是: 0.107219
优化2900次后成本是: 0.104072
优化3000次后成本是: 0.101097
优化3100次后成本是: 0.098280
优化3200次后成本是: 0.095610
优化3300次后成本是: 0.093075
优化3400次后成本是: 0.090667
优化3500次后成本是: 0.088374
优化3600次后成本是: 0.086190
优化3700次后成本是: 0.084108
优化3800次后成本是: 0.082119
优化3900次后成本是: 0.080219
优化4000次后成本是: 0.078402
优化4100次后成本是: 0.076662
优化4200次后成本是: 0.074994
优化4300次后成本是: 0.073395
优化4400次后成本是: 0.071860
优化4500次后成本是: 0.070385
优化4600次后成本是: 0.068968
优化4700次后成本是: 0.067604
优化4800次后成本是: 0.066291
优化4900次后成本是: 0.065027
优化5000次后成本是: 0.063807
对训练图片的预测准确率为: 100.0%%%%
对测试图片的预测准确率为: 70.0%%%%
在这里插入图片描述


学习率调过好几次,准确率似乎只有70%%%%,要想提高要么改进网络结构要么学习更多数据。

代码都在上一个博客里,函数都一样,应用的背景不一样而已。

推荐阅读