Bucles con theano

Uso de escaneo básico

scan se usa para llamar a la función varias veces sobre una lista de valores, la función puede contener estado.

Sintaxis scan (a partir de theano 0.9):

scan(
    fn,
    sequences=None,
    outputs_info=None,
    non_sequences=None,
    n_steps=None,
    truncate_gradient=-1,
    go_backwards=False,
    mode=None,
    name=None,
    profile=False,
    allow_gc=None,
    strict=False)

Esto puede ser muy confuso a primera vista. Explicaremos varios usos básicos pero importantes de “escaneo” en varios ejemplos de código.

Los siguientes ejemplos de código asumen que ha ejecutado importaciones:

import numpy as np
import theano
import theano.tensor as T

secuencias - Mapea una función sobre una lista

En el caso más simple, escanear simplemente asigna una función pura (una función sin estado) a una lista. Las listas se especifican en el argumento secuencias

  s_x = T.ivector()
  s_y, _ = theano.scan(
      fn = lambda x:x*x,
      sequences = [s_x])
  fn = theano.function([s_x], s_y)
  fn([1,2,3,4,5]) #[1,4,9,16,25]

Tenga en cuenta que scan tiene dos valores de retorno, el primero es la lista resultante y el último son las actualizaciones del valor de estado, que se explicarán más adelante.

secuencias - Comprima una función sobre una lista

Casi lo mismo que arriba, solo dale al argumento secuencias una lista de dos elementos. El orden de los dos elementos debe coincidir con el orden de los argumentos en fn

  s_x1 = T.ivector()
  s_x2 = T.ivector()
  s_y, _ = theano.scan(
      fn = lambda x1,x2:x1**x2,
      sequences = [s_x1, s_x2])
  fn = theano.function([s_x], s_y)
  fn([1,2,3,4,5],[0,1,2,3,4]) #[1,2,9,64,625]

outputs_info - Acumular una lista

La acumulación implica una variable de estado. Las variables de estado necesitan valores iniciales, que se especificarán en el parámetro outputs_info.

  s_x = T.ivector()
  v_sum = th.shared(np.int32(0))
  s_y, update_sum = theano.scan(
      lambda x,y:x+y,
      sequences = [s_x],
      outputs_info = [s_sum])
  fn = theano.function([s_x], s_y, updates=update_sum)
  
  v_sum.get_value() # 0
  fn([1,2,3,4,5]) # [1,3,6,10,15]
  v_sum.get_value() # 15
  fn([-1,-2,-3,-4,-5]) # [14,12,9,5,0]
  v_sum.get_value() # 0

Ponemos una variable compartida en outputs_info, esto hará que scan devuelva actualizaciones a nuestra variable compartida, que luego se puede poner en theano.function.

non_sequences y n_steps - Órbita del mapa logístico x -> lambda*x*(1-x)

Puede dar entradas que no cambien durante scan en el argumento non_sequences. En este caso, s_lambda es una variable que no cambia (pero NO una constante, ya que debe proporcionarse durante el tiempo de ejecución).

  s_x = T.fscalar()
  s_lambda = T.fscalar()
  s_t = T.iscalar()
  s_y, _ = theano.scan(
      fn = lambda x,l: l*x*(1-x),
      outputs_info = [s_x],
      non_sequences = [s_lambda],
      n_steps = s_t
  )
  fn = theano.function([s_x, s_lambda, s_t], s_y)

  fn(.75, 4., 10) #a stable orbit

  #[ 0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75,  0.75]

  fn(.65, 4., 10) #a chaotic orbit

  #[ 0.91000003,  0.32759991,  0.88111287,  0.41901192,  0.97376364,
  # 0.10219204,  0.3669953 ,  0.92923898,  0.2630156 ,  0.77535355]

Grifos - Fibonacci

los estados/entradas pueden venir en múltiples pasos de tiempo. Esto se hace por:

  • poner dict(input=<init_value>, taps=<list of int>) dentro del argumento sequences.

  • poner dict(initial=<init_value>, taps=<list of int>) dentro del argumento outputs_info.

En este ejemplo, usamos dos toques en outputs_info para calcular la relación de recurrencia x_n = x_{n-1} + x_{n-2}.

s_x0 = T.iscalar()
s_x1 = T.iscalar()
s_n = T.iscalar()
s_y, _ = theano.scan(
    fn = lambda x1,x2: x1+x2,
    outputs_info = [dict(initial=T.join(0,[s_x0, s_x1]), taps=[-2,-1])],
    n_steps = s_n
)
fn_fib = theano.function([s_x0, s_x1, s_n], s_y)
fn_fib(1,1,10)
# [2, 3, 5, 8, 13, 21, 34, 55, 89, 144]

theo mapear y reducir

theano.map y theano.scan_module.reduce son contenedores de theano_scan. Pueden verse como una versión para minusválidos de scan. Puede ver la sección Uso de escaneo básico como referencia.

import theano
import theano.tensor as T
s_x = T.ivector()
s_sqr, _ = theano.map(
    fn = lambda x:x*x,
    sequences = [s_x])
s_sum, _ = theano.reduce(
    fn = lambda: x,y:x+y,
    sequences = [s_x],
    outputs_info = [0])
fn = theano.function([s_x], [s_sqr, s_sum])
fn([1,2,3,4,5]) #[1,4,9,16,25], 15

haciendo un ciclo while

A partir de theano 0.9, los bucles while se pueden realizar a través de theano.scan_module.scan_utils.until. Para usarlo, debe devolver el objeto hasta en fn de scan.

En el siguiente ejemplo, construimos una función que verifica si un número complejo está dentro del conjunto de Mandelbrot. Un número complejo z_0 está dentro del conjunto de Mandelbrot si la serie z_{n+1} = z_{n}^2 + z_0 no converge.

MAX_ITER = 256
BAILOUT = 2.
s_z0 = th.cscalar()
def iterate(s_i_, s_z_, s_z0_):
    return [s_z_*s_z_+s_z0_,s_i_+1], {}, until(T.abs_(s_z_)>BAILOUT)
(_1, s_niter), _2 = theano.scan(
    fn = iterate,
    outputs_info = [0, s_z0],
    non_sequences = [s_z0],
    n_steps = MAX_ITER
)
fn_mandelbrot_iters = theano.function([s_z0], s_niter)
def is_in_mandelbrot(z_):
    return fn_mandelbrot_iters(z_)>=MAX_ITER

is_in_mandelbrot(0.24+0.j) # True
is_in_mandelbrot(1.j) # True
is_in_mandelbrot(0.26+0.j) # False