数字を分類するニューラルネットワークの実装をやってみる2-2


数字を分類するニューラルネットワークの実装をやってみる2-1の続きです。

GitHubにあるnetwork2.pyをやってみます。


network2.pyでは

MNISTの分類にクロス(交差)エントロピーを使っています。

 

ubuntu 16.04 LTS + Python 3.x用にnetwork2.pyを修正

 

$python3

>>>import mnist_loader
>>>training_data, validation_data, test_data = mnist_loader.load_data_wrapper()
>>>import network2
>>>net = network2.Network([784, 30, 10], cost=network2.CrossEntropyCost)
>>>net.large_weight_initializer()
>>>net.SGD(training_data, 30, 10, 0.5, evaluation_data=test_data,
monitor_evaluation_accuracy=True)

 

 

 


network2.py

学習したモデルはsave()関数で保存できます。

>>>net.save(“nn2”)

ネットワークモデル、重み、バイアス、コスト関数名がJSON形式で保存されます。

 


ソース(Python3.x)

Be the first to comment

Leave a Reply

Your email address will not be published.


*