Using an if condition inside the TensorFlow graph with tf.cond

Parameters

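| Parameter | Details |
|-----------|---------|
| pred | a scalar TensorFlow tensor of type `bool` |
| fn1 | a callable with no arguments, whose result is returned when `pred` is `True` |
| fn2 | a callable with no arguments, whose result is returned when `pred` is `False` |
| name | (optional) name for the operation |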
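For reference, here is a call that passes every parameter. It is a minimal sketch assuming TensorFlow 2.x driven through its `tf.compat.v1` graph API (on TensorFlow 1.x, `import tensorflow as tf` works instead). It passes the two functions by keyword as `true_fn` and `false_fn`, the names that replaced `fn1` and `fn2` in later TensorFlow 1.x releases. The scope name `my_cond` is an arbitrary illustrative choice.

```python
import tensorflow.compat.v1 as tf  # assumes TensorFlow 2.x; on 1.x use `import tensorflow as tf`

tf.disable_v2_behavior()  # build a graph first, run it later in a Session

x = tf.constant(1.)
pred = tf.constant(True)

# All four parameters; `my_cond` is only an illustrative name.
res = tf.cond(pred,
              true_fn=lambda: tf.add(x, 1.),
              false_fn=lambda: tf.add(x, 10.),
              name="my_cond")

with tf.Session() as sess:
    print(res.name)       # the op's tensors are created under the "my_cond" name scope
    print(sess.run(res))  # 2.0
```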

Remarks

• `pred` cannot be a plain Python `True` or `False`; it must be a Tensor.
• The functions `fn1` and `fn2` must return the same number of outputs, with the same types.
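Since `pred` has to be a boolean tensor, it is usually produced by a comparison op such as `tf.less`. The sketch below, assuming TensorFlow 2.x with the `tf.compat.v1` graph API, picks the larger of two fed values; the placeholder names `a` and `b` are illustrative. Both branches return a single `float32` tensor, so they also agree in number and type of outputs.

```python
import tensorflow.compat.v1 as tf  # assumes TensorFlow 2.x; on 1.x use `import tensorflow as tf`

tf.disable_v2_behavior()

a = tf.placeholder(tf.float32, shape=[])
b = tf.placeholder(tf.float32, shape=[])

# tf.less returns a scalar bool tensor, which is a valid `pred`
pred = tf.less(a, b)
larger = tf.cond(pred, lambda: b, lambda: a)

with tf.Session() as sess:
    print(sess.run(larger, feed_dict={a: 3., b: 7.}))  # 7.0
    print(sess.run(larger, feed_dict={a: 5., b: 2.}))  # 5.0
```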

Basic example

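```python
import tensorflow as tf  # TensorFlow 1.x graph API

x = tf.constant(1.)
pred = tf.constant(True)  # named `pred` to avoid shadowing the built-in `bool`

# true branch: x + 1.; false branch: x + 10. (the false value is only illustrative)
res = tf.cond(pred, lambda: tf.add(x, 1.), lambda: tf.add(x, 10.))
# sess.run(res) will give you 2.
```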
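To see why `tf.cond` is needed at all, compare it with a plain Python `if`: in graph mode a tensor has no value until the graph runs, so using it as a Python `bool` raises a `TypeError`. A minimal sketch, assuming TensorFlow 2.x with the `tf.compat.v1` graph API; the false branch `x + 10.` is illustrative.

```python
import tensorflow.compat.v1 as tf  # assumes TensorFlow 2.x; on 1.x use `import tensorflow as tf`

tf.disable_v2_behavior()  # graph mode: ops are built first, run later in a Session

x = tf.constant(1.)
pred = tf.constant(True)

# A Python `if` needs a concrete bool while the graph is being built, but
# `pred` only has a value when the graph runs, so TensorFlow raises TypeError.
python_if_failed = False
try:
    if pred:
        pass
except TypeError as err:
    python_if_failed = True
    print("Python if rejected the tensor:", type(err).__name__)

# tf.cond records both branches in the graph; the choice is made at run time.
res = tf.cond(pred, lambda: tf.add(x, 1.), lambda: tf.add(x, 10.))

with tf.Session() as sess:
    print(sess.run(res))  # 2.0
```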

When fn1 and fn2 return multiple tensors

The two functions `fn1` and `fn2` can return multiple tensors, but they have to return the exact same number and types of outputs.

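```python
import tensorflow as tf  # TensorFlow 1.x graph API

x = tf.constant(1.)
pred = tf.constant(True)

def fn1():
    return tf.add(x, 1.), tf.identity(x)   # two float32 outputs

def fn2():
    return tf.add(x, 10.), tf.identity(x)  # illustrative values; same number and types as fn1

res1, res2 = tf.cond(pred, fn1, fn2)
# tf.cond returns two tensors, unpacked here into res1 and res2
# sess.run([res1, res2]) will return [2., 1.]
```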
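If the two functions disagree, `tf.cond` refuses to build the op. In the sketch below (assuming TensorFlow 2.x with the `tf.compat.v1` graph API; the function names `two_outputs` and `one_output` are hypothetical) one branch returns two tensors and the other returns one. The exact exception class and message differ between TensorFlow versions, so the sketch catches both `ValueError` and `TypeError`.

```python
import tensorflow.compat.v1 as tf  # assumes TensorFlow 2.x; on 1.x use `import tensorflow as tf`

tf.disable_v2_behavior()

x = tf.constant(1.)
pred = tf.constant(True)

def two_outputs():
    return tf.add(x, 1.), tf.identity(x)

def one_output():
    return tf.add(x, 10.)

# The branches return different numbers of outputs, so graph construction fails.
mismatch_rejected = False
try:
    tf.cond(pred, two_outputs, one_output)
except (ValueError, TypeError) as err:
    mismatch_rejected = True
    print(type(err).__name__, err)
```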

Define and use functions fn1 and fn2 with parameters

You can pass parameters to the functions in `tf.cond()` by wrapping each call in a `lambda`, as below:

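```python
import tensorflow as tf  # TensorFlow 1.x graph API

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)

def fn1(a, b):
    return tf.multiply(a, b)  # tf.mul was renamed tf.multiply in TensorFlow 1.0

def fn2(a, b):
    return tf.add(a, b)

pred = tf.placeholder(tf.bool)
# each lambda takes no arguments and forwards the tensors to fn1 / fn2
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))
```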

Then you can run it as follows:

```python
with tf.Session() as sess:
    print(sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True}))
    # The result is 2.0
    print(sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False}))
    # The result is 5.0
```
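A `lambda` is not the only way to bind arguments: `tf.cond` only needs zero-argument callables, so `functools.partial` from the Python standard library can be swapped in. A sketch under the same assumption (TensorFlow 2.x with the `tf.compat.v1` graph API), with `fn1` and `fn2` redefined so the block is self-contained:

```python
import functools

import tensorflow.compat.v1 as tf  # assumes TensorFlow 2.x; on 1.x use `import tensorflow as tf`

tf.disable_v2_behavior()

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)
pred = tf.placeholder(tf.bool, shape=[])

def fn1(a, b):
    return tf.multiply(a, b)

def fn2(a, b):
    return tf.add(a, b)

# functools.partial binds the arguments now and yields a zero-argument
# callable, which is all tf.cond needs for each branch.
result = tf.cond(pred, functools.partial(fn1, x, y), functools.partial(fn2, y, z))

with tf.Session() as sess:
    print(sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True}))   # 2.0
    print(sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False}))  # 5.0
```

One difference from the `lambda` version: `functools.partial` captures `x`, `y` and `z` when it is created, while a `lambda` looks the names up when `tf.cond` calls it. With fixed placeholders, as here, both build the same graph.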

This article is an extract of the original Stack Overflow Documentation created by contributors and released under CC BY-SA 3.0. This website is not affiliated with Stack Overflow.