Loops com theano
Uso básico de varredura
scan
é usado para chamar a função várias vezes em uma lista de valores, a função pode conter estado.
Sintaxe scan
(a partir do theano 0.9):
scan(
fn,
sequences=None,
outputs_info=None,
non_sequences=None,
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
mode=None,
name=None,
profile=False,
allow_gc=None,
strict=False)
Isso pode ser muito confuso à primeira vista. Explicaremos vários usos básicos, mas importantes, do scan
em vários exemplos de código.
Os exemplos de código a seguir pressupõem que você executou importações:
import numpy as np
import theano
import theano.tensor as T
sequences
- Mapeia uma função sobre uma lista
No caso mais simples, scan apenas mapeia uma função pura (uma função sem estado) para uma lista. As listas são especificadas no argumento sequences
s_x = T.ivector()
s_y, _ = theano.scan(
fn = lambda x:x*x,
sequences = [s_x])
fn = theano.function([s_x], s_y)
fn([1,2,3,4,5]) #[1,4,9,16,25]
Nota scan
tem dois valores de retorno, o primeiro é a lista resultante, e o último é a atualização do valor do estado, que será explicado mais tarde.
sequences
- Compacte uma função sobre uma lista
Quase o mesmo que acima, apenas dê ao argumento sequences
uma lista de dois elementos. A ordem dos dois elementos deve corresponder à ordem dos argumentos em fn
s_x1 = T.ivector()
s_x2 = T.ivector()
s_y, _ = theano.scan(
fn = lambda x1,x2:x1**x2,
sequences = [s_x1, s_x2])
fn = theano.function([s_x], s_y)
fn([1,2,3,4,5],[0,1,2,3,4]) #[1,2,9,64,625]
outputs_info
- Acumule uma lista
A acumulação envolve uma variável de estado. As variáveis de estado precisam de valores iniciais, que devem ser especificados no parâmetro outputs_info
.
s_x = T.ivector()
v_sum = th.shared(np.int32(0))
s_y, update_sum = theano.scan(
lambda x,y:x+y,
sequences = [s_x],
outputs_info = [s_sum])
fn = theano.function([s_x], s_y, updates=update_sum)
v_sum.get_value() # 0
fn([1,2,3,4,5]) # [1,3,6,10,15]
v_sum.get_value() # 15
fn([-1,-2,-3,-4,-5]) # [14,12,9,5,0]
v_sum.get_value() # 0
Colocamos uma variável compartilhada em outputs_info
, isso fará com que scan
retorne atualizações para nossa variável compartilhada, que pode então ser colocada em theano.function
.
non_sequences
e n_steps
- Órbita do mapa logístico x -> lambda*x*(1-x)
Você pode dar entradas que não mudam durante o scan
no argumento non_sequences
. Neste caso, s_lambda
é uma variável imutável (mas NÃO uma constante, pois deve ser fornecida durante o tempo de execução).
s_x = T.fscalar()
s_lambda = T.fscalar()
s_t = T.iscalar()
s_y, _ = theano.scan(
fn = lambda x,l: l*x*(1-x),
outputs_info = [s_x],
non_sequences = [s_lambda],
n_steps = s_t
)
fn = theano.function([s_x, s_lambda, s_t], s_y)
fn(.75, 4., 10) #a stable orbit
#[ 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75, 0.75]
fn(.65, 4., 10) #a chaotic orbit
#[ 0.91000003, 0.32759991, 0.88111287, 0.41901192, 0.97376364,
# 0.10219204, 0.3669953 , 0.92923898, 0.2630156 , 0.77535355]
Torneiras - Fibonacci
estados/entradas podem vir em vários passos de tempo. Isso é feito por:
-
colocando
dict(input=<init_value>, taps=<list of int>)
dentro do argumentosequences
. -
colocando
dict(initial=<init_value>, taps=<list of int>)
dentro do argumentooutputs_info
.
Neste exemplo, usamos dois toques em outputs_info
para calcular a relação de recorrência x_n = x_{n-1} + x_{n-2}
.
s_x0 = T.iscalar()
s_x1 = T.iscalar()
s_n = T.iscalar()
s_y, _ = theano.scan(
fn = lambda x1,x2: x1+x2,
outputs_info = [dict(initial=T.join(0,[s_x0, s_x1]), taps=[-2,-1])],
n_steps = s_n
)
fn_fib = theano.function([s_x0, s_x1, s_n], s_y)
fn_fib(1,1,10)
# [2, 3, 5, 8, 13, 21, 34, 55, 89, 144]
mapear e reduzir theo
theano.map
e theano.scan_module.reduce
são wrappers de theano_scan
. Eles podem ser vistos como uma versão deficiente do scan
. Você pode ver a seção Uso básico de varredura para referência.
import theano
import theano.tensor as T
s_x = T.ivector()
s_sqr, _ = theano.map(
fn = lambda x:x*x,
sequences = [s_x])
s_sum, _ = theano.reduce(
fn = lambda: x,y:x+y,
sequences = [s_x],
outputs_info = [0])
fn = theano.function([s_x], [s_sqr, s_sum])
fn([1,2,3,4,5]) #[1,4,9,16,25], 15
fazendo loop while
A partir do theano 0.9, os loops while podem ser feitos via theano.scan_module.scan_utils.until
.
Para usar, você deve retornar o objeto until
em fn
de scan
.
No exemplo a seguir, construímos uma função que verifica se um número complexo está dentro do conjunto de Mandelbrot. Um número complexo z_0
está dentro do conjunto de mandelbrot se a série z_{n+1} = z_{n}^2 + z_0
não converge.
MAX_ITER = 256
BAILOUT = 2.
s_z0 = th.cscalar()
def iterate(s_i_, s_z_, s_z0_):
return [s_z_*s_z_+s_z0_,s_i_+1], {}, until(T.abs_(s_z_)>BAILOUT)
(_1, s_niter), _2 = theano.scan(
fn = iterate,
outputs_info = [0, s_z0],
non_sequences = [s_z0],
n_steps = MAX_ITER
)
fn_mandelbrot_iters = theano.function([s_z0], s_niter)
def is_in_mandelbrot(z_):
return fn_mandelbrot_iters(z_)>=MAX_ITER
is_in_mandelbrot(0.24+0.j) # True
is_in_mandelbrot(1.j) # True
is_in_mandelbrot(0.26+0.j) # False