Skip to content

Commit

Permalink
Update lab-07-1-learning_rate_and_evaluation.py
Browse files Browse the repository at this point in the history
1. Modified to a learning rate that is appropriate for learning.
2. The arg_max is modified to argmax.
  • Loading branch information
qoocrab authored and kkweon committed Jan 13, 2019
1 parent 107d2eb commit be4ea88
Showing 1 changed file with 6 additions and 21 deletions.
27 changes: 6 additions & 21 deletions lab-07-1-learning_rate_and_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
[1, 0, 0],
[1, 0, 0]]


# Evaluation our model using this test dataset
x_test = [[2, 1, 1],
[3, 1, 2],
Expand All @@ -41,12 +40,11 @@
# Cross entropy cost/loss
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(hypothesis), axis=1))
# Try to change learning_rate to small numbers
optimizer = tf.train.GradientDescentOptimizer(
learning_rate=1e-10).minimize(cost)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(cost)

# Correct prediction Test model
prediction = tf.arg_max(hypothesis, 1)
is_correct = tf.equal(prediction, tf.arg_max(Y, 1))
prediction = tf.argmax(hypothesis, 1)
is_correct = tf.equal(prediction, tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

# Launch graph
Expand All @@ -55,8 +53,7 @@
sess.run(tf.global_variables_initializer())

for step in range(201):
cost_val, W_val, _ = sess.run(
[cost, W, optimizer], feed_dict={X: x_data, Y: y_data})
cost_val, W_val, _ = sess.run([cost, W, optimizer], feed_dict={X: x_data, Y: y_data})
print(step, cost_val, W_val)

# predict
Expand All @@ -66,6 +63,7 @@

'''
when lr = 1.5
0 5.73203 [[-0.30548954 1.22985029 -0.66033536]
[-4.39069986 2.29670858 2.99386835]
[-3.34510708 2.09743214 -0.80419564]]
Expand All @@ -87,27 +85,19 @@
6 nan [[ nan nan nan]
[ nan nan nan]
[ nan nan nan]]
...
Prediction: [0 0 0]
Accuracy: 0.0
-------------------------------------------------
When lr = 1e-10
0 5.73203 [[ 0.80269563 0.67861295 -1.21728313]
[-0.3051686 -0.3032113 1.50825703]
[ 0.75722361 -0.7008909 -2.10820389]]
1 5.73203 [[ 0.80269563 0.67861295 -1.21728313]
[-0.3051686 -0.3032113 1.50825703]
[ 0.75722361 -0.7008909 -2.10820389]]
2 5.73203 [[ 0.80269563 0.67861295 -1.21728313]
[-0.3051686 -0.3032113 1.50825703]
[ 0.75722361 -0.7008909 -2.10820389]]
...
198 5.73203 [[ 0.80269563 0.67861295 -1.21728313]
[-0.3051686 -0.3032113 1.50825703]
[ 0.75722361 -0.7008909 -2.10820389]]
199 5.73203 [[ 0.80269563 0.67861295 -1.21728313]
[-0.3051686 -0.3032113 1.50825703]
[ 0.75722361 -0.7008909 -2.10820389]]
Expand All @@ -125,12 +115,7 @@
1 3.318 [[ 0.66219079 0.74796319 -1.14612854]
[-0.81948912 0.03000021 1.68936598]
[ 0.23214608 -0.33772916 -1.94628811]]
2 2.0218 [[ 0.64342022 0.74127686 -1.12067163]
[-0.81161296 -0.00900121 1.72049117]
[ 0.2086665 -0.35079569 -1.909742 ]]
...
199 0.672261 [[-1.15377033 0.28146935 1.13632679]
[ 0.37484586 0.18958236 0.33544877]
[-0.35609841 -0.43973011 -1.25604188]]
Expand Down

0 comments on commit be4ea88

Please sign in to comment.