Boucles avec theano
Utilisation de base de l’analyse
scan
est utilisé pour appeler la fonction plusieurs fois sur une liste de valeurs, la fonction peut contenir un état.
Syntaxe scan
(à partir de theano 0.9):
scan(
fn,
sequences=None,
outputs_info=None,
non_sequences=None,
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
mode=None,
name=None,
profile=False,
allow_gc=None,
strict=False)
Cela peut être très déroutant à première vue. Nous expliquerons plusieurs utilisations basiques mais importantes de “scan” dans plusieurs exemples de code.
Les exemples de code suivants supposent que vous avez exécuté des importations :
import numpy as np
import theano
import theano.tensor as T
sequences
- Mappez une fonction sur une liste
Dans le cas le plus simple, scan ne fait qu’associer une fonction pure (une fonction sans état) à une liste. Les listes sont spécifiées dans l’argument sequences
s_x = T.ivector()
s_y, _ = theano.scan(
fn = lambda x:x*x,
sequences = [s_x])
fn = theano.function([s_x], s_y)
fn([1,2,3,4,5]) #[1,4,9,16,25]
Notez que scan
a deux valeurs de retour, la première est la liste résultante et la seconde est les mises à jour de la valeur d’état, qui seront expliquées plus tard.
sequences
- Compressez une fonction sur une liste
Presque comme ci-dessus, donnez simplement à l’argument “séquences” une liste de deux éléments. L’ordre des deux éléments doit correspondre à l’ordre des arguments dans fn
s_x1 = T.ivector()
s_x2 = T.ivector()
s_y, _ = theano.scan(
fn = lambda x1,x2:x1**x2,
sequences = [s_x1, s_x2])
fn = theano.function([s_x], s_y)
fn([1,2,3,4,5],[0,1,2,3,4]) #[1,2,9,64,625]
outputs_info
- Accumuler une liste
L’accumulation implique une variable d’état. Les variables d’état ont besoin de valeurs initiales, qui doivent être spécifiées dans le paramètre outputs_info
.
s_x = T.ivector()
v_sum = th.shared(np.int32(0))
s_y, update_sum = theano.scan(
lambda x,y:x+y,
sequences = [s_x],
outputs_info = [s_sum])
fn = theano.function([s_x], s_y, updates=update_sum)
v_sum.get_value() # 0
fn([1,2,3,4,5]) # [1,3,6,10,15]
v_sum.get_value() # 15
fn([-1,-2,-3,-4,-5]) # [14,12,9,5,0]
v_sum.get_value() # 0
Nous mettons une variable partagée dans outputs_info
, cela entraînera scan
retour des mises à jour de notre variable partagée, qui peuvent ensuite être mises dans theano.function
.
non_sequences
et n_steps
- Orbite de la carte logistique x -> lambda*x*(1-x)
Vous pouvez donner des entrées qui ne changent pas pendant scan
dans l’argument non_sequences
. Dans ce cas, s_lambda
est une variable non changeante (mais PAS une constante puisqu’elle doit être fournie pendant l’exécution).
s_x = T.fscalar()
s_lambda = T.fscalar()
s_t = T.iscalar()
s_y, _ = theano.scan(
fn = lambda x,l: l*x*(1-x),
outputs_info = [s_x],
non_sequences = [s_lambda],
n_steps = s_t
)
fn = theano.function([s_x, s_lambda, s_t], s_y)
fn(.75, 4., 10) #a stable orbit
#[ 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75]
fn(.65, 4., 10) #a chaotic orbit
#[ 0.91000003, 0.32759991, 0.88111287, 0.41901192, 0.97376364,
# 0.10219204, 0.3669953 , 0.92923898, 0.2630156 , 0.77535355]
Taps - Fibonacci
les états/entrées peuvent se présenter en plusieurs étapes de temps. Cela se fait par :
-
mettre
dict(input=<init_value>, taps=<list of int>)
dans l’argumentsequences
. -
mettre
dict(initial=<init_value>, taps=<list of int>)
dans l’argumentoutputs_info
.
Dans cet exemple, nous utilisons deux taps dans outputs_info
pour calculer la relation de récurrence x_n = x_{n-1} + x_{n-2}
.
s_x0 = T.iscalar()
s_x1 = T.iscalar()
s_n = T.iscalar()
s_y, _ = theano.scan(
fn = lambda x1,x2: x1+x2,
outputs_info = [dict(initial=T.join(0,[s_x0, s_x1]), taps=[-2,-1])],
n_steps = s_n
)
fn_fib = theano.function([s_x0, s_x1, s_n], s_y)
fn_fib(1,1,10)
# [2, 3, 5, 8, 13, 21, 34, 55, 89, 144]
theo mapper et réduire
theano.map
et theano.scan_module.reduce
sont des enveloppes de theano_scan
. Ils peuvent être considérés comme une version handicapée de scan
. Vous pouvez consulter la section Utilisation de base de l’analyse pour référence.
import theano
import theano.tensor as T
s_x = T.ivector()
s_sqr, _ = theano.map(
fn = lambda x:x*x,
sequences = [s_x])
s_sum, _ = theano.reduce(
fn = lambda: x,y:x+y,
sequences = [s_x],
outputs_info = [0])
fn = theano.function([s_x], [s_sqr, s_sum])
fn([1,2,3,4,5]) #[1,4,9,16,25], 15
faire une boucle while
Depuis theano 0.9, les boucles while peuvent être effectuées via theano.scan_module.scan_utils.until
.
Pour l’utiliser, vous devez renvoyer l’objet until
dans fn
de scan
.
Dans l’exemple suivant, nous construisons une fonction qui vérifie si un nombre complexe est à l’intérieur de l’ensemble de Mandelbrot. Un nombre complexe z_0
est à l’intérieur de l’ensemble de Mandelbrot si la série z_{n+1} = z_{n}^2 + z_0
ne converge pas.
MAX_ITER = 256
BAILOUT = 2.
s_z0 = th.cscalar()
def iterate(s_i_, s_z_, s_z0_):
return [s_z_*s_z_+s_z0_,s_i_+1], {}, until(T.abs_(s_z_)>BAILOUT)
(_1, s_niter), _2 = theano.scan(
fn = iterate,
outputs_info = [0, s_z0],
non_sequences = [s_z0],
n_steps = MAX_ITER
)
fn_mandelbrot_iters = theano.function([s_z0], s_niter)
def is_in_mandelbrot(z_):
return fn_mandelbrot_iters(z_)>=MAX_ITER
is_in_mandelbrot(0.24+0.j) # True
is_in_mandelbrot(1.j) # True
is_in_mandelbrot(0.26+0.j) # False