Bucles con theano
En esta página
Uso de escaneo básico
scan
se usa para llamar a la función varias veces sobre una lista de valores, la función puede contener estado.
Sintaxis scan
(a partir de theano 0.9):
scan(
fn,
sequences=None,
outputs_info=None,
non_sequences=None,
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
mode=None,
name=None,
profile=False,
allow_gc=None,
strict=False)
Esto puede ser muy confuso a primera vista. Explicaremos varios usos básicos pero importantes de “escaneo” en varios ejemplos de código.
Los siguientes ejemplos de código asumen que ha ejecutado importaciones:
import numpy as np
import theano
import theano.tensor as T
secuencias
- Mapea una función sobre una lista
En el caso más simple, escanear simplemente asigna una función pura (una función sin estado) a una lista. Las listas se especifican en el argumento secuencias
s_x = T.ivector()
s_y, _ = theano.scan(
fn = lambda x:x*x,
sequences = [s_x])
fn = theano.function([s_x], s_y)
fn([1,2,3,4,5]) #[1,4,9,16,25]
Tenga en cuenta que scan
tiene dos valores de retorno, el primero es la lista resultante y el último son las actualizaciones del valor de estado, que se explicarán más adelante.
secuencias
- Comprima una función sobre una lista
Casi lo mismo que arriba, solo dale al argumento secuencias
una lista de dos elementos. El orden de los dos elementos debe coincidir con el orden de los argumentos en fn
s_x1 = T.ivector()
s_x2 = T.ivector()
s_y, _ = theano.scan(
fn = lambda x1,x2:x1**x2,
sequences = [s_x1, s_x2])
fn = theano.function([s_x], s_y)
fn([1,2,3,4,5],[0,1,2,3,4]) #[1,2,9,64,625]
outputs_info
- Acumular una lista
La acumulación implica una variable de estado. Las variables de estado necesitan valores iniciales, que se especificarán en el parámetro outputs_info
.
s_x = T.ivector()
v_sum = th.shared(np.int32(0))
s_y, update_sum = theano.scan(
lambda x,y:x+y,
sequences = [s_x],
outputs_info = [s_sum])
fn = theano.function([s_x], s_y, updates=update_sum)
v_sum.get_value() # 0
fn([1,2,3,4,5]) # [1,3,6,10,15]
v_sum.get_value() # 15
fn([-1,-2,-3,-4,-5]) # [14,12,9,5,0]
v_sum.get_value() # 0
Ponemos una variable compartida en outputs_info
, esto hará que scan
devuelva actualizaciones a nuestra variable compartida, que luego se puede poner en theano.function
.
non_sequences
y n_steps
- Órbita del mapa logístico x -> lambda*x*(1-x)
Puede dar entradas que no cambien durante scan
en el argumento non_sequences
. En este caso, s_lambda
es una variable que no cambia (pero NO una constante, ya que debe proporcionarse durante el tiempo de ejecución).
s_x = T.fscalar()
s_lambda = T.fscalar()
s_t = T.iscalar()
s_y, _ = theano.scan(
fn = lambda x,l: l*x*(1-x),
outputs_info = [s_x],
non_sequences = [s_lambda],
n_steps = s_t
)
fn = theano.function([s_x, s_lambda, s_t], s_y)
fn(.75, 4., 10) #a stable orbit
#[ 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75]
fn(.65, 4., 10) #a chaotic orbit
#[ 0.91000003, 0.32759991, 0.88111287, 0.41901192, 0.97376364,
# 0.10219204, 0.3669953 , 0.92923898, 0.2630156 , 0.77535355]
Grifos - Fibonacci
los estados/entradas pueden venir en múltiples pasos de tiempo. Esto se hace por:
-
poner
dict(input=<init_value>, taps=<list of int>)
dentro del argumentosequences
. -
poner
dict(initial=<init_value>, taps=<list of int>)
dentro del argumentooutputs_info
.
En este ejemplo, usamos dos toques en outputs_info
para calcular la relación de recurrencia x_n = x_{n-1} + x_{n-2}
.
s_x0 = T.iscalar()
s_x1 = T.iscalar()
s_n = T.iscalar()
s_y, _ = theano.scan(
fn = lambda x1,x2: x1+x2,
outputs_info = [dict(initial=T.join(0,[s_x0, s_x1]), taps=[-2,-1])],
n_steps = s_n
)
fn_fib = theano.function([s_x0, s_x1, s_n], s_y)
fn_fib(1,1,10)
# [2, 3, 5, 8, 13, 21, 34, 55, 89, 144]
theo mapear y reducir
theano.map
y theano.scan_module.reduce
son contenedores de theano_scan
. Pueden verse como una versión para minusválidos de scan
. Puede ver la sección Uso de escaneo básico como referencia.
import theano
import theano.tensor as T
s_x = T.ivector()
s_sqr, _ = theano.map(
fn = lambda x:x*x,
sequences = [s_x])
s_sum, _ = theano.reduce(
fn = lambda: x,y:x+y,
sequences = [s_x],
outputs_info = [0])
fn = theano.function([s_x], [s_sqr, s_sum])
fn([1,2,3,4,5]) #[1,4,9,16,25], 15
haciendo un ciclo while
A partir de theano 0.9, los bucles while se pueden realizar a través de theano.scan_module.scan_utils.until
.
Para usarlo, debe devolver el objeto hasta
en fn
de scan
.
En el siguiente ejemplo, construimos una función que verifica si un número complejo está dentro del conjunto de Mandelbrot. Un número complejo z_0
está dentro del conjunto de Mandelbrot si la serie z_{n+1} = z_{n}^2 + z_0
no converge.
MAX_ITER = 256
BAILOUT = 2.
s_z0 = th.cscalar()
def iterate(s_i_, s_z_, s_z0_):
return [s_z_*s_z_+s_z0_,s_i_+1], {}, until(T.abs_(s_z_)>BAILOUT)
(_1, s_niter), _2 = theano.scan(
fn = iterate,
outputs_info = [0, s_z0],
non_sequences = [s_z0],
n_steps = MAX_ITER
)
fn_mandelbrot_iters = theano.function([s_z0], s_niter)
def is_in_mandelbrot(z_):
return fn_mandelbrot_iters(z_)>=MAX_ITER
is_in_mandelbrot(0.24+0.j) # True
is_in_mandelbrot(1.j) # True
is_in_mandelbrot(0.26+0.j) # False