Boucles avec theano

Utilisation de base de l’analyse

scan est utilisé pour appeler la fonction plusieurs fois sur une liste de valeurs, la fonction peut contenir un état.

Syntaxe scan (à partir de theano 0.9):

scan(
    fn,
    sequences=None,
    outputs_info=None,
    non_sequences=None,
    n_steps=None,
    truncate_gradient=-1,
    go_backwards=False,
    mode=None,
    name=None,
    profile=False,
    allow_gc=None,
    strict=False)

Cela peut être très déroutant à première vue. Nous expliquerons plusieurs utilisations basiques mais importantes de “scan” dans plusieurs exemples de code.

Les exemples de code suivants supposent que vous avez exécuté des importations :

import numpy as np
import theano
import theano.tensor as T

sequences - Mappez une fonction sur une liste

Dans le cas le plus simple, scan ne fait qu’associer une fonction pure (une fonction sans état) à une liste. Les listes sont spécifiées dans l’argument sequences

  s_x = T.ivector()
  s_y, _ = theano.scan(
      fn = lambda x:x*x,
      sequences = [s_x])
  fn = theano.function([s_x], s_y)
  fn([1,2,3,4,5]) #[1,4,9,16,25]

Notez que scan a deux valeurs de retour, la première est la liste résultante et la seconde est les mises à jour de la valeur d’état, qui seront expliquées plus tard.

sequences - Compressez une fonction sur une liste

Presque comme ci-dessus, donnez simplement à l’argument “séquences” une liste de deux éléments. L’ordre des deux éléments doit correspondre à l’ordre des arguments dans fn

  s_x1 = T.ivector()
  s_x2 = T.ivector()
  s_y, _ = theano.scan(
      fn = lambda x1,x2:x1**x2,
      sequences = [s_x1, s_x2])
  fn = theano.function([s_x], s_y)
  fn([1,2,3,4,5],[0,1,2,3,4]) #[1,2,9,64,625]

outputs_info - Accumuler une liste

L’accumulation implique une variable d’état. Les variables d’état ont besoin de valeurs initiales, qui doivent être spécifiées dans le paramètre outputs_info.

  s_x = T.ivector()
  v_sum = th.shared(np.int32(0))
  s_y, update_sum = theano.scan(
      lambda x,y:x+y,
      sequences = [s_x],
      outputs_info = [s_sum])
  fn = theano.function([s_x], s_y, updates=update_sum)
  
  v_sum.get_value() # 0
  fn([1,2,3,4,5]) # [1,3,6,10,15]
  v_sum.get_value() # 15
  fn([-1,-2,-3,-4,-5]) # [14,12,9,5,0]
  v_sum.get_value() # 0

Nous mettons une variable partagée dans outputs_info, cela entraînera scan retour des mises à jour de notre variable partagée, qui peuvent ensuite être mises dans theano.function.

non_sequences et n_steps - Orbite de la carte logistique x -> lambda*x*(1-x)

Vous pouvez donner des entrées qui ne changent pas pendant scan dans l’argument non_sequences. Dans ce cas, s_lambda est une variable non changeante (mais PAS une constante puisqu’elle doit être fournie pendant l’exécution).

  s_x = T.fscalar()
  s_lambda = T.fscalar()
  s_t = T.iscalar()
  s_y, _ = theano.scan(
      fn = lambda x,l: l*x*(1-x),
      outputs_info = [s_x],
      non_sequences = [s_lambda],
      n_steps = s_t
  )
  fn = theano.function([s_x, s_lambda, s_t], s_y)

  fn(.75, 4., 10) #a stable orbit

  #[ 0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75]

  fn(.65, 4., 10) #a chaotic orbit

  #[ 0.91000003,  0.32759991,  0.88111287,  0.41901192,  0.97376364,
  # 0.10219204,  0.3669953 ,  0.92923898,  0.2630156 ,  0.77535355]

Taps - Fibonacci

les états/entrées peuvent se présenter en plusieurs étapes de temps. Cela se fait par :

  • mettre dict(input=<init_value>, taps=<list of int>) dans l’argument sequences.

  • mettre dict(initial=<init_value>, taps=<list of int>) dans l’argument outputs_info.

Dans cet exemple, nous utilisons deux taps dans outputs_info pour calculer la relation de récurrence x_n = x_{n-1} + x_{n-2}.

s_x0 = T.iscalar()
s_x1 = T.iscalar()
s_n = T.iscalar()
s_y, _ = theano.scan(
    fn = lambda x1,x2: x1+x2,
    outputs_info = [dict(initial=T.join(0,[s_x0, s_x1]), taps=[-2,-1])],
    n_steps = s_n
)
fn_fib = theano.function([s_x0, s_x1, s_n], s_y)
fn_fib(1,1,10)
# [2, 3, 5, 8, 13, 21, 34, 55, 89, 144]

theo mapper et réduire

theano.map et theano.scan_module.reduce sont des enveloppes de theano_scan. Ils peuvent être considérés comme une version handicapée de scan. Vous pouvez consulter la section Utilisation de base de l’analyse pour référence.

import theano
import theano.tensor as T
s_x = T.ivector()
s_sqr, _ = theano.map(
    fn = lambda x:x*x,
    sequences = [s_x])
s_sum, _ = theano.reduce(
    fn = lambda: x,y:x+y,
    sequences = [s_x],
    outputs_info = [0])
fn = theano.function([s_x], [s_sqr, s_sum])
fn([1,2,3,4,5]) #[1,4,9,16,25], 15

faire une boucle while

Depuis theano 0.9, les boucles while peuvent être effectuées via theano.scan_module.scan_utils.until. Pour l’utiliser, vous devez renvoyer l’objet until dans fn de scan.

Dans l’exemple suivant, nous construisons une fonction qui vérifie si un nombre complexe est à l’intérieur de l’ensemble de Mandelbrot. Un nombre complexe z_0 est à l’intérieur de l’ensemble de Mandelbrot si la série z_{n+1} = z_{n}^2 + z_0 ne converge pas.

MAX_ITER = 256
BAILOUT = 2.
s_z0 = th.cscalar()
def iterate(s_i_, s_z_, s_z0_):
    return [s_z_*s_z_+s_z0_,s_i_+1], {}, until(T.abs_(s_z_)>BAILOUT)
(_1, s_niter), _2 = theano.scan(
    fn = iterate,
    outputs_info = [0, s_z0],
    non_sequences = [s_z0],
    n_steps = MAX_ITER
)
fn_mandelbrot_iters = theano.function([s_z0], s_niter)
def is_in_mandelbrot(z_):
    return fn_mandelbrot_iters(z_)>=MAX_ITER

is_in_mandelbrot(0.24+0.j) # True
is_in_mandelbrot(1.j) # True
is_in_mandelbrot(0.26+0.j) # False