from __future__ import print_function
import keras
from keras.models import Sequential
from keras.layers import Conv2D,MaxPooling2D
from keras.layers import Dense,Dropout,Flatten
from keras.datasets import mnist
from keras import backend as K

# Training hyper-parameters.
epochs = 10
batch_size = 256

# Dataset geometry: 10 target classes, 28x28 input images.
num_classes = 10
img_rows = img_cols = 28

# Load MNIST as (images, labels) pairs for the training and test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add an explicit single-channel axis, placed where the active Keras backend
# expects it: 'channels_first' -> (N, 1, H, W), otherwise (N, H, W, 1).
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Convert pixels to float32 and scale them from [0, 255] into [0, 1];
# unscaled integer inputs make optimisation slower and less stable.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

print("x_train shape:", x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# The network ends in a num_classes-way softmax and is trained with
# categorical cross-entropy, so integer class labels must be one-hot
# encoded to shape (N, num_classes). keras.utils.to_categorical replaces
# the removed keras.utils.np_utils path.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Build the CNN classifier.
# Experiment log (reported accuracy) for alternative first layers:
#   (1) Conv2D(32, (3, 3))  -- the original example program  -> 0.99
#   (2) Dense(32) only, no Conv2D; slightly worse             -> 0.9607
#   (3) Conv2D(32, (5, 5))  -- current configuration          -> 0.9869
model = Sequential([
    Conv2D(32, (5, 5), activation='relu', input_shape=input_shape),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),  # drop half of the 128 dense activations during training
    Dense(32, activation='relu'),
    Dense(num_classes, activation="softmax"),
])

# Compile with categorical cross-entropy (taken from keras.losses, where loss
# functions live), the Adadelta optimiser, and accuracy as a reported metric.
# NOTE(review): Adadelta's default learning rate differs across Keras
# versions (1.0 in older standalone Keras, 0.001 from Keras 2.3 onwards);
# confirm the installed version trains fast enough with the default.
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])

# Train, reporting loss/accuracy on the test split after each epoch.
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))
print("fit end")

# Evaluate on the test split; with metrics=['accuracy'] the result is
# [loss, accuracy].
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])


