TensorFlow: Understanding tf.contrib.seq2seq.BasicDecoder
I am trying to understand tf.contrib.seq2seq.BasicDecoder. Every example on the web just uses this wrapper, but I couldn't find an explanation of what tf.contrib.seq2seq.BasicDecoder actually does. I tried it with one simple example:
import numpy as np
import tensorflow as tf
from pprint import pprint
from tensorflow.python.framework import tensor_shape
from tensorflow.contrib.rnn import BasicRNNCell
from tensorflow.contrib.seq2seq.python.ops.basic_decoder import BasicDecoder, BasicDecoderOutput
from tensorflow.contrib.seq2seq.python.ops.helper import TrainingHelper
from tensorflow.python.layers.core import Dense
sequence_length = [3, 4, 3, 1, 3]
batch_size = 5
max_time = 8
input_size = 7
hidden_size = 10
output_size = 3
inputs = np.random.randn(batch_size, max_time, input_size).astype(np.float32)
output_layer = Dense(output_size) # will get a trainable variable size [hidden_size x output_size]
dec_cell = BasicRNNCell(hidden_size)
helper = TrainingHelper(inputs, sequence_length)
decoder = BasicDecoder(
    cell=dec_cell,
    helper=helper,
    initial_state=dec_cell.zero_state(dtype=tf.float32, batch_size=batch_size),
    output_layer=output_layer)
first_finished, first_inputs, first_state = decoder.initialize()
(first_finished, first_inputs, first_state)
step_outputs, step_state, step_next_inputs, step_finished = decoder.step(
    tf.constant(0), first_inputs, first_state)
(step_outputs, step_state, step_next_inputs, step_finished)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    results = sess.run({
        "batch_size": decoder.batch_size,
        "first_finished": first_finished,
        "first_inputs": first_inputs,
        "first_state": first_state,
        "step_outputs": step_outputs,
        "step_state": step_state,
        "step_next_inputs": step_next_inputs,
        "step_finished": step_finished})
    pprint(results)
The output is:
{'batch_size': 5,
'first_finished': array([False, False, False, False, True]),
'first_inputs': array([[-0.1305329 , 0.7027261 , -0.8157375 , 0.01787353, 2.3610914 ,
0.8905939 , -0.2685608 ],
[-1.1782284 , 1.6488065 , 0.58254075, 0.12861735, 0.47683764,
-2.05314 , -0.166469 ],
[ 0.8365086 , -1.7963833 , -2.5053551 , 2.3320568 , -0.357463 ,
-0.01917691, 0.5789354 ],
[-1.7942209 , -0.19699056, 0.42065838, -0.81790465, 2.5130792 ,
1.2232817 , 0.7819383 ],
[ 1.2460921 , -0.16332811, 0.70908403, -1.334465 , -0.10106717,
-0.26541698, -1.3249161 ]], dtype=float32),
'first_state': array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
'step_finished': array([False, False, False, True, True]),
'step_next_inputs': array([[ 1.3291198 , -0.15886226, 1.4437864 , 0.41159418, 0.55492574,
-0.90773547, 0.83662 ],
[ 1.0856647 , 2.3009017 , 1.2625048 , -0.7682241 , -0.58327836,
-1.2566029 , 0.32073924],
[ 0.2532574 , 1.3086783 , -0.6277142 , 1.8158357 , -0.9641214 ,
-0.4462067 , -0.11307725],
[ 0.48346692, -0.58842784, 0.4114005 , 0.23313236, -0.81712246,
-1.4564492 , 0.7117556 ],
[ 0.7588838 , -0.82005906, 0.663568 , 0.24783312, -1.4573535 ,
1.4284246 , -0.30952594]], dtype=float32),
'step_outputs': BasicDecoderOutput(rnn_output=array([[ 1.4097914 , -0.69918895, -1.2088122 ],
[-1.266958 , -0.8121094 , -0.03660662],
[ 0.40251616, -0.11823708, 0.23454508],
[ 1.3780088 , -0.86239576, -0.9247706 ],
[ 0.09462224, -0.14165601, 0.39751652]], dtype=float32), sample_id=array([0, 2, 0, 0, 2], dtype=int32)),
'step_state': array([[-0.19132493, 0.8753218 , 0.07888561, -0.6356789 , 0.72481483,
0.4161568 , 0.7337458 , 0.06502081, 0.20294249, -0.73887783],
[ 0.4778563 , 0.1592015 , -0.86701995, 0.8127028 , 0.09732129,
-0.9266094 , -0.5395306 , -0.8694291 , 0.87705237, -0.545192 ],
[ 0.66678804, 0.82219815, 0.9689762 , -0.9692538 , -0.3958014 ,
0.24547155, 0.05074365, 0.0893333 , -0.5242875 , 0.18463017],
[-0.8668696 , 0.9405894 , -0.69780034, -0.1462304 , 0.9349755 ,
0.41605997, 0.9185027 , -0.07991812, -0.5194315 , -0.5538262 ],
[ 0.47941405, -0.8954227 , -0.7062361 , 0.3774918 , 0.28503373,
0.617851 , -0.36548492, 0.2932893 , 0.3323133 , -0.35999647]],
dtype=float32)}
I understand that it returns the RNN output and sample_id, but I am confused about the time argument and the boolean finished output.
The parameters of the tf.contrib.seq2seq.BasicDecoder step function are:
step(
    time,
    inputs,
    state,
    name=None
)
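A minimal sketch, in plain Python rather than TensorFlow, of how a decode loop could supply `time` to a step function. `toy_step` and `toy_decode` are hypothetical names for illustration only. The finished rule used here (`time + 1 >= sequence_length`, with an initial check for length 0) is an assumption based on my reading of `TrainingHelper`, not something confirmed in this question.

```python
def toy_step(time, sequence_length):
    """Hypothetical stand-in for decoder.step: returns per-sequence finished flags.

    Assumed rule: a sequence is finished when the NEXT step index (time + 1)
    has reached its length.
    """
    next_time = time + 1
    return [next_time >= length for length in sequence_length]


def toy_decode(sequence_length):
    """Hypothetical driver loop: calls toy_step with time = 0, 1, 2, ...

    Stops once every sequence in the batch is finished.
    """
    # Assumed initial check: only zero-length sequences start out finished.
    finished = [length == 0 for length in sequence_length]
    time = 0
    history = []
    while not all(finished):
        step_flags = toy_step(time, sequence_length)
        # A sequence that has finished stays finished.
        finished = [old or new for old, new in zip(finished, step_flags)]
        history.append((time, finished))
        time += 1
    return history


history = toy_decode([3, 4, 3, 1, 3])
for time, flags in history:
    print(time, flags)
```

In this sketch `time` is a zero-based step counter that the caller increments, not a sequence length. Under the assumed rule, a flag can only be `True` at `time=0` when the corresponding length is 0 or 1, and the flags for `time=2` already mark the length-3 sequences as finished.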
What does time actually represent here? My sequence length is [3, 4, 3, 1, 3]. If I call:
decoder.step(tf.constant(1), step_next_inputs, step_state)
the output is:
array([False, False, False, True, True]))}
So it seems the 4th and 5th sequences have finished unrolling. I took this to mean that I have to pass the sequence length as time, so I tried:
decoder.step(tf.constant(3), step_next_inputs, step_state)
I expected the output to be:
array([True, False, True, True, True]))}
but I am getting:
array([ True, True, True, True, True]))}
How does this work? Even if I pass an arbitrary value for time, it does not raise an error. Does that mean it can unroll an arbitrary number of times?
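One possible explanation for the missing error, sketched in plain Python: if the helper returns zero-valued inputs once every sequence is finished, it never reads past the end of the input data. `toy_next_inputs` is a hypothetical name, and the branch on "all finished" is an assumption about how `TrainingHelper` behaves, not a confirmed description of the library.

```python
def toy_next_inputs(time, inputs, sequence_length):
    """Hypothetical stand-in for the helper's next-input logic.

    inputs is time-major: inputs[t][b] is the input for batch entry b at step t.
    """
    next_time = time + 1
    finished = [next_time >= length for length in sequence_length]
    if all(finished):
        # Assumed behaviour: no read happens, so a large `time` cannot
        # index out of range; zeros are returned instead.
        next_batch = [0.0] * len(sequence_length)
    else:
        # A real read; this would fail if next_time were past the last step.
        next_batch = inputs[next_time]
    return finished, next_batch


sequence_length = [3, 4, 3, 1, 3]
max_time = 8
# Toy inputs: the value at step t is simply float(t) for every batch entry.
inputs = [[float(t)] * len(sequence_length) for t in range(max_time)]

for time in (2, 3, 100):
    finished, next_batch = toy_next_inputs(time, inputs, sequence_length)
    print(time, finished, next_batch)
```

Under these assumptions, a single step call does not unroll anything by itself. It only reports, for the given `time`, which sequences would be past their end at the next step, which is why `time=3` and any larger value both yield all `True` here.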
Here is a Google Colab notebook; you can run this code online in my notebook.
Please provide some information about this.
Thank you!
@nbro, hi, here is a Google Colab notebook that you can run online: colab.research.google.com/drive/…
– Ayodhyankit Paul
Jul 2 at 22:28
You should definitely start by removing unused (commented) parts of your code
– desertnaut
Jul 3 at 3:52
It would be very nice and helpful if you could simplify your code as much as it is sufficient to illustrate the problem, otherwise it is a little painful to have to read all the code.
– nbro
Jul 2 at 22:22