로그인 바로가기 하위 메뉴 바로가기 본문 바로가기

인공지능 및 기계학습 심화

임시 이미지 KAIST 산업및시스템공학과 문일철 교수 KOOC (KAIST Open Online Course)
http://www.edwith.org/aiml-adv/forum/130181
좋아요 689 수강생 3324

안녕하세요, 조교 김동준입니다.

지난시간에 diffusion model의 forward-time Data SDE는 특정한 형태인 VE나 VP의 형태를 가지는 것을 살펴보았고,

이것이 일반적인 SDE가 되었을 때에는 두 가지 어려움으로 ( x_t\vert x_0  xtx0 샘플링,  p_{0t}(x_t\vert x_0)  p0t(xtx0) 계산) 인하여 학습이 불가능함을 살펴보았습니다.

이번 시간에는 두 가지 어려움 중  p_{0t}(x_t\vert x_0)  p0t(xtx0) 계산 불가능성을 해결했다고 주장하는 논문을 살펴보도록 하겠습니다.

논문은 Diffusion Normalizing Flow (https://arxiv.org/abs/2110.07579)를 참조하였습니다.

해당 논문에서는 일반적인 형태의  dx_t=f(x_t,t)dt+g(t)dw_t  dxt=f(xt,t)dt+g(t)dwt 의 forward-time data SDE로 정한 후 reverse-time data SDE의  dx_t=[f(x_t,t)-g^2(t)\nabla\log{p_t(x_t)}]d\bar{t}+g(t)d\bar{w}_t  dxt=[f(xt,t)g2(t)logpt(xt)]dt¯+g(t)dw¯t 에서의 data score인  \nabla\log{p_t(x_t)}  logpt(xt) 를 score neural network인  s_\theta(x_t,t)  sθ(xt,t) 를 통하여 예측하여  dx_t=[f(x_t,t)-g^2(t)s_\theta(x_t,t)]d\bar{t}+g(t)d\bar{w}_t  dxt=[f(xt,t)g2(t)sθ(xt,t)]dt¯+g(t)dw¯t 와 같은 reverse-time generative SDE를 만듭니다.

사용되는 forward-time SDE가 VE인  dx_t=g(t)dw_t  dxt=g(t)dwt  혹은 VP인  dx_t=-\frac{1}{2}\beta(t)x_tdt+\sqrt{\beta(t)}dw_t  dxt=2 1β(t)xtdt+β(t) dwt 의 형태가 아닌, general SDE라는 점을 제외하면 논문에서 하는 일은 기존 diffusion model에서 하는 일과 정확히 동일하다는 것을 알 수 있습니다.

여기서, 원래의 denoising diffusion loss는  L(\theta;\lambda)=\int_{0}^{T}\lambda(t)E_{x_{0},x_{t}}[\Vert s_\theta(x_t,t)-\nabla\log{p_{0t}(x_t\vert x_0)}\Vert_{2}^{2}]dt  L(θ;λ)=0Tλ(t)Ex0,xt[sθ(xt,t)logp0t(xtx0)22]dt 입니다. 이 loss function을 discretized Markov chain에 대하여 구해볼까요? discretized Markov chain은 연속 시간  t  t 에 대하여 존재하는 stochastic process인  \{x_{t}\}_{t=0}^{T}  {xt}t=0T 를 불연속 시간  t  t 에 대한 또 다른 stochastic process인  \{x_{t_n}\}_{n=0}^{N}  {xtn}n=0N 으로 근사하여 구할 수 있습니다. 이 discrete Markov chain에 대한 diffusion model은 DDPM에서 정의된 바와 같이, stochastic process  \{x_{t_n}\}_{n=0}^{N}  {xtn}n=0N 의 joint distribution인  q(x_{t_0},...,x_{t_{N}})  q(xt0,...,xtN) 과 우리가 실제로 손으로 생성할 수 있는 generative stochastic process의 joint distribution인  p_{\theta}(x_{t_{0}},...,x_{t_{N}})  pθ(xt0,...,xtN) 사이의 KL divergence를 최적화하여 구할 수 있습니다.

즉, continuous-time을 discrete-time으로 낮출 때는 stochastic process를 discrete-time으로 이산화할 필요가 있습니다. 여기서, VE 혹은 VP의 경우에는 linear하기 때문에 stochastic process의 정확한 solution을 구할 수 있고, 이산화할 때 큰 오차가 없이 이산화할 수 있습니다. 반면 general SDE의 경우에는 어떻게 될까요?

일반적인 SDE의 이산화의 정확도는 보통 drift term (및 volatility term)의 Lipschitz constant에 비례하게 됩니다. 그렇기 때문에 만약 우리의 관심사가 continuous-time SDE를 보다 정확히 근사하는 discrete-time stochastic process를 구하여 그 process를 예측하는 것에 있다면 general SDE로 무작정 접근하는 방법은 error control을 하기 어렵기 때문에 지양해야 합니다.

그럼에도 불구하고 하지 않는 것보다는 좀 에러가 있다 해도 한번 근사해본 후에 학습을 해보면 어떨까요? discrete-time stochastic process는 다음과 같이  x_{t_{i}}^{EM}=x_{t_{i-1}}^{EM}+f(x_{t_{i-1}}^{EM},t_{i-1})(t_{i}-t_{i-1})+g(t_{i-1})\epsilon\sqrt{t_{i}-t_{i-1}}  xtiEM=xti1EM+f(xti1EM,ti1)(titi1)+g(ti1)ϵtiti1  이라고 근사할 수 있어서 이제 더이상 비선형성은 사라지게 되었습니다. 그렇기 때문에 data perturbation은 Gaussian distribution에 따라 진행되고 결국 기존에 사용하였던 loss function을 그대로 사용할 수 있게 되는 것입니다.

이는 중요한 점을 야기합니다. 즉, continuous-time을 discrete-time으로 근사하는 과정에서 비선형성을 없애 discrete transition probability를 Gaussian으로 만들었습니다. 이러한 근사가 없이는 Gaussianity가 더이상 만족되지 않기 때문에 우리는 loss function을 tractable하게 만들 수 없게 되는 것입니다.

비선형성을 나타내는  f(x_t,t)  f(xt,t) 가 당장에  x_t^2  xt2 에만 비례해도 Lipschitz constant는 무한대가 되어 discrete-time 근사화의 에러를 bound할 수 없게 됩니다. 그렇기 때문에 이러한 형태의 general SDE 학습은 아주 제한 범위가 좁다고 이야기할 수 있겠습니다.

다음 시간에는 위의 문제를 semi-linear diffusion model인 Schrodinger Bridge Problem에서 어떻게 해결하였는지 살펴보겠습니다.

감사합니다