๋ณธ ํฌ์คํธ๋ ์ ๊ฐ ํด๋จผ์ค์ผ์ดํ ๊ธฐ์ ๋ธ๋ก๊ทธ์ ๋จผ์ ์์ฑํ๊ณ ์ฎ๊ธด ํฌ์คํธ์
๋๋ค.
๋ณธ ํฌ์คํธ์์๋ ์ด์ ์ ํฌ์คํธํ ๋
ผ๋ฌธ ๋ฆฌ๋ทฐ์ธย StackGAN์ Conditioning Augmentation Layer ์์ conditioning vector ๋ฅผ ๋ง๋ค์ด๋ด๋ ๊ณผ์ ์ ๊ทผ์์ด ๋๋ ๋
ผ๋ฌธ์ ๋ํด์ ๋ฆฌ๋ทฐํ๋ ค๊ณ ํฉ๋๋ค. ๋ฆฌ๋ทฐํ๋ ค๋ ๋
ผ๋ฌธ์ ์ ๋ชฉ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
โAuto-Encoding Variational Bayesโ
Objective
๋
ผ๋ฌธ์ ๋ฐฐ๊ฒฝ์ ์ฌ์ ํ๋ฅ ๋ถํฌ๋ฅผ ํ์ตํ๊ธฐ ์ํด ๊ตฌ์ฑํ ์ฌํ ํ๋ฅ ๋ถํฌ๊ฐ intractable ํ์ผ๋ฉฐ, ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์กด์ฌํ๋ Monte Carlo Estimation ๋ฐฉ๋ฒ ํจ์จ์ ์ด์ง ๋ชปํ๋ ๊ฒ์์ ์์ํฉ๋๋ค.
์์ ๊ฐ์ด ์ค๋ช
๋๋ฆฌ๋ฉด ์ด๊ฒ ๋ฌด์จ ์๋ฆฌ์งโฆ? ํ์ค ๋ถ๋ค์ด ๊ฝค๋ ๋ง์ผ์ค ๊ฒ ๊ฐ์์ ์ฐ์ ๋ฐฐ๊ฒฝ์ ๋ํ ์ค๋ช
์ ํ๊ธฐ ์ด์ ์ ์๋์์ ํ์ํ ๋ฐฐ๊ฒฝ ์ง์๋ถํฐ ์ฐจ๊ทผ์ฐจ๊ทผ ์ค๋ช
ํ๋ ค๊ณ ํฉ๋๋ค.
Background
[์ฃผ์]ย ์ดํดํด์ผ ํ ๋ฐฐ๊ฒฝ์ง์์ดย ๋ง์ต๋๋ค. ์ต๋ํ ์ดํด๋๋ฅผ ๋์ด๋ ๋ฐฉํฅ์ผ๋ก ์์ ํ์ผ๋ ๋ง์์ ์ค๋น๋ฅผ ํด ์ฃผ์ธ์.
๊ฒฐ๋ก ๋ถํฐ ๋ง์๋๋ฆฌ์๋ฉด VAE ๋ Generative Model ์
๋๋ค. ์ด์ ์ ์ ๊ฐ ์ผ๋ ํฌ์คํธ ์ค์๋ย ๋น์ทํ ์น๊ตฌ๊ฐ ์์ต๋๋ค. ๋น์ทํ ์น๊ตฌ์ธ GAN ์ ๋ํด์ ์ฒ์์ด์ ๋ถ๋ค์ ์ํด์ ๊ฐ๋จํ ์ค๋ช
์ ๋ถ์ด์๋ฉด GAN์ noise vector ๋ก ๋ถํฐ ํน์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์
๊ณผ ์ ์ฌํ ์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ฅผ ์์ฑํด๋ด๋ model ์
๋๋ค. ๋์น์ฑ์
จ๊ฒ ์ง๋ง, VAE ๋ ๋น์ทํ ๊ธฐ๋ฅ์ ํฉ๋๋ค. ๋ค๋ง ๋ค๋ฅธ ์ ์ input ์ด noise vector ๊ฐ ์๋ ์ด๋ฏธ์ง์
๋๋ค. ์ฌ์ค GAN ์ด ์ ๋ช
ํด์ ๊ทธ๋ ์ง, VAE ๊ฐ GAN ๋ณด๋ค ์ด์ ์ ์ธ์์ ๊ณต๊ฐ๋ Generative Model ์
๋๋ค.
ํ ๋ฒ ๋ ์ธ๊ธํ๋ ๊ฒ์ด๊ฒ ์ง๋ง, VAE ์์ ํ์ตํ๋ ๊ฒ์ ์ด๋ฏธ์ง๋ก๋ถํฐ ์ ์ฌํ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๊ธฐ ์ํ model parameters ๋ค์
๋๋ค. ์ฆ, VAE ์์ ๊ถ๊ทน์ ์ผ๋ก ์ํ๋ ๊ฒ์ ๋ชจ๋ธ์ ํตํด์ ์์ฑ๋ ์ด๋ฏธ์ง๊ฐ ํ์ต์ ์ฌ์ฉ๋ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์
์ ๋ํด ๊ฐ์ง๋ ์ ์ฌ๋๋ฅผ ๊ฐ์ฅ ๋์ผ ์ ์๋ model parameters ๋ฅผ ํ์ตํ๋ ๊ฒ์
๋๋ค.
์ ๋ฌธ์ฅ์ ๊ณฑ์น์ด ์๊ฐํด๋ด
์๋ค. ์ ์ฌ๋๋ฅผ ํ๋ฅ ์ด๋ผ๊ณ ๋ณด์์ ๋, ํ๋ฅ ์ด ๊ฐ์ฅ ๋์ ๋ฌด์ธ๊ฐ๋ฅผ ์ ํํ๋ ๊ฒ์ ์ด๋๊ฐ ์ต์ํ ์ํฉ์
๋๋ค.
MLE(Maximum Likelihood Estimation)ย ๋ก ์ ์๋ ค์ง ์ถ๋ก ๋ฐฉ๋ฒ์ ์ด๋ค ์ํฉ์ด ์ฃผ์ด์ก์ ๋ ๊ทธ ์ํฉ์ ๊ฐ์ฅ ๋์ ํ๋ฅ ๋ก ์ฐ์ถํ๋ ํ๋ณด๊ตฐ์ ์ ํํ๋ ๋ฐฉ๋ฒ์
๋๋ค. ๊ฐ๋จํ๋ฉด์๋ ์ ๋ช
ํ ์์๋ก ๋
์ ๋จ์ด์ง ๋จธ๋ฆฌ์นด๋ฝ์ ๋ฐ๊ฒฌํ๋๋ฐ, ์ด ๊ฒ์ด ๋จ์์ ๋จธ๋ฆฌ์นด๋ฝ์ธ์ง ์ฌ์์ ๋จธ๋ฆฌ์นด๋ฝ์ธ์ง๋ฅผ ์ ํํ๋ ๊ฒฝ์ฐ๋ฅผ ๋ค ์ ์์ต๋๋ค. MLE ๋ ๋จธ๋ฆฌ์นด๋ฝ์ ๊ธธ์ด ๋ณ ๊ทธ ๊ฒ์ด ๋จ์์ ๊ฒ์ผ ํ๋ฅ ๋ถํฌ๋ฅผ ํตํด ์ต์ข
์ ์ผ๋ก ํ๋ฅ 0.5 ๋ฅผ ๊ธฐ์ค์ผ๋ก ์ถ๋ก ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฌ๋ฆฌ ํ๊ฒ ๋ฉ๋๋ค.
ํ์ง๋ง, MLE ์๋ ์น๋ช
์ ์ธ ๋จ์ ์ด ์์ต๋๋ค. ๋ณดํต์ ๊ฒฝ์ฐ์ ๋ฐ์ํ๋ ์ถ๋ก ์ย ์ฌํ ํ๋ฅ ๋ถํฌ๊ฐ ์กด์ฌํ์ง ์๋ ๊ฒฝ์ฐ๊ฐ ๋ง๋ค๋ ์ ์
๋๋ค. ์ ์์๋ ๊ทธ๋ด์ธํ์ง๋ง, โ๋จธ๋ฆฌ์นด๋ฝ์ ๊ธธ์ด ๋ณ ๊ทธ ๊ฒ์ด ๋จ์์ ๊ฒ์ผ ํ๋ฅ ๋ถํฌโ ๋ ๋ณดํต์ ๊ฒฝ์ฐ์๋ ์กด์ฌํ์ง ์๋๋ค๋ ๊ฒ์ด์ฃ . ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๋ํ๋ ๋ฐฉ๋ฒ์ด MAP ์
๋๋ค.
MAP(Maximum A Posteriori)ย ๋ Bayes Theorem ๋ฅผ ์ฌ์ฉํ์ฌ ์ ๋ฌธ์ ๋ฅผ ํ๊ฐํ ์ถ๋ก ๋ฐฉ๋ฒ์
๋๋ค.
PR(A|B) ๋ฅผ ์ ํฌ๊ฐ ๊ตฌํ๊ณ ์ถ์ B: ๋จธ๋ฆฌ์นด๋ฝ์ ๊ธธ์ด๊ฐ X ์ผ ๋, A: ๊ทธ๊ฒ์ด ๋จ์์ ๊ฒ์ผ ํ๋ฅ ์ด๋ผ๊ณ ํ๋ค๋ฉด,์ด๋ PR(A): ์ ์ฒด ์ฌ๋๋ค ์ค ํน์ ์ฌ๋์ด ๋จ์์ผ ํ๋ฅ ์ PR(B|A): ๋จ์๊ฐ X ๊ธธ์ด์ ๋จธ๋ฆฌ์นด๋ฝ์ ๊ฐ์ง ํ๋ฅ ์ ๊ณฑํ ๋ค์ PR(B): ์ ์ฒด ๋จธ๋ฆฌ์นด๋ฝ๋ค ์ค ๊ธธ์ด๊ฐ X์ผ ํ๋ฅ ๋ก ๋๋์ด์ ๊ตฌํ ์ ์๋ค๋ ๊ฒ์
๋๋ค.
ํ์คํ, โ์ ์ฒด ์ฌ๋๋ค ์ค ํน์ ์ฌ๋์ด ๋จ์์ผ ํ๋ฅ โ, โ ๋จ์๊ฐ X ๊ธธ์ด์ ๋จธ๋ฆฌ์นด๋ฝ์ ๊ฐ์ง ํ๋ฅ โ, โ์ ์ฒด ๋จธ๋ฆฌ์นด๋ฝ๋ค ์ค ๊ธธ์ด๊ฐ X์ผ ํ๋ฅ โ ์ ์์ ๊ตฌํ๋ ค๊ณ ํ๋ โ๋จธ๋ฆฌ์นด๋ฝ์ ๊ธธ์ด๊ฐ X ์ผ ๋, ๊ทธ๊ฒ์ด ๋จ์์ ๊ฒ์ผ ํ๋ฅ " ๋ณด๋ค ๋ฏธ๋ฆฌ ์๊ณ ์์ ๊ฐ๋ฅ์ฑ์ด ๋์๋ณด์
๋๋ค.
๋ค์ VAE ๋ก ๋์๊ฐ๋ด
์๋ค.
์ ํฌ๋ โ๋ชจ๋ธ์ ํตํด์ ์์ฑ๋ ์ด๋ฏธ์ง๊ฐ ํ์ต์ ์ฌ์ฉ๋ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์
์ ๋ํด ๊ฐ์ง๋ ์ ์ฌ๋๋ฅผ ๊ฐ์ฅ ๋์ผ ์ ์๋ model parameters ๋ฅผ ํ์ตโ ์ด MLE ์ ์ถ๋ก ๊ณผ ๋น์ทํ๋ค๋ ์ ์ ๋๊ผ๊ณ , ์ด MLE ์ ๋ฌธ์ ์ ์ ํด๊ฒฐํ๊ธฐ ์ํด MAP ๋ฅผ ์ ์ํ ์ ์๊ฒ ๋์์ต๋๋ค. ์กฐ๊ธ ๋ ์ง์ ์ ์ผ๋ก ๋ง์๋๋ฆฌ๋ฉด, ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
Model parameters ์ ๋ณต์กํ โ์กฐํฉ๋ณ๋ก ์์ฑ๋ ์ด๋ฏธ์ง๊ฐ ํ์ต์ ์ฌ์ฉ๋ ์ด๋ฏธ์ง๋ค๊ณผ ๊ฐ์ง๋ ์ ์ฌ๋์ ๋ํ ํ๋ฅ ๋ถํฌโ๋ฅผ ์๊ธฐ ์ด๋ ต๊ธฐ ๋๋ฌธ์ MLE ๊ฐ ์๋ MAP ๋ฅผ ์ฌ์ฉํ๋ ค๊ณ ์๋ํ ์ ์๋ค๋ ๊ฒ์
๋๋ค.
ํ์ง๋ง, MAP ๋ฅผ ์ฌ์ฉํ๋ ค๊ณ ํด๋ ๋ฌธ์ ์ ์ด ๋ํ๋ฉ๋๋ค.
์ฌ์ ํ๋ฅ p_theta(x) ๋ฅผ ๊ตฌํ๊ธฐ ์ํด์ ์์ ๊ฐ์ ์์ ๊ณ์ฐํด์ผ ํ๋๋ฐ ๋ฌด์์ธ์ง๋ ๋ชจ๋ฅผ ์กฐํฉ๋ค์ ๊ฐ์ง latent variables ๋ค์ ์ ๋ถํ๋ค๋ ๊ฒ์ด ๊ฐ๋ฅํ์ง ์์์ต๋๋ค. ์ด๋ ๊ฒ ์ฌ์ ํ๋ฅ ๋ถํฌ์กฐ์ฐจ ๊ณ์ฐ ๋ถ๊ฐ๋ฅํ ์ ์ ์ธ๊ณต ์ ๊ฒฝ๋ง์ย Intractabilityย ๋ก ์ธ๊ธํ์ฌ ๋ฌธ์ ์ ์ด๋ผ๊ณ ์นญํ๊ณ ์์ต๋๋ค.
์ด๋ฌํ Intractability ์ ๋ง์๋ ๋ฐฉ๋ฒ ์ค ๊ฐ์ฅ naive ํ ๋ฐฉ๋ฒ์ดย Monte Carlo Estimationย ์ ์ฌ์ฉํ๋ ๊ฒ์
๋๋ค. Monte Carlo Estimation ์ ์ํ์ ๋ฐ๋ณต์ ํตํด์ ํ๋ฅ ๋ถํฌ๋ฅผ ์ป์ด๋ด๋ ๋ฐฉ๋ฒ์
๋๋ค. ์ด ๊ฒฝ์ฐ์๋ ์ฌ์ ํ๋ฅ ๋ถํฌ๋ฅผ ์๊ธฐ ์ด๋ ต๊ธฐ ๋๋ฌธ์ ๋ฐ๋ณต ์ํ์ ํตํด latent variables ์ ๋ฐ๋ผ ์์ฑ๋๋ ์ด๋ฏธ์ง ํ๋ฅ ๋ถํฌ์ ๋ํ ๊ธฐ๋๊ฐ์ ๊ณ์ฐํ๋ค๊ณ ๋ณผ ์ ์์ต๋๋ค. ํ์ง๋ง, Monte Carlo Estimation ์ ์ฌ์ฉํ๊ธฐ ์ํด์๋ ํ๋์ theta๊ฐ ๋ฌ๋ผ์ง batch ์์์ ๊ฐ๊ฐ sampling ์ ์งํํด์ผ ํ๊ณ ํฐ ๋ฐ์ดํฐ์
์์ sampling ์ ์งํํ๋ค๋ ๊ฒ์ย ํ์ต ์๋๋ฅผ ๋๋ฆฌ๊ฒ ํ๋ค๋ ๋จ์ ์ด ์์ต๋๋ค.
์ด ์์ ์์ ๋ค์ย Objectiveย ์ ์ ํ ๊ธ๊ท๋ก ๋์๊ฐ๋ณด์ธ์. ์ด๋์ ๋ ์ดํด๊ฐ ๋์ค ๊ฒ์ด๋ผ ๋ฏฟ์ต๋๋ค.
Qualitative Objective
์์ ์ฌ์ ํ๋ฅ ๊ฐ intractable ํ ๋ฌธ์ ์ ์ง๋ฉดํ์์ต๋๋ค. ๋ฐ๋ผ์ ์ด๋ฌํ ๊ฒฝ์ฐ์์๋ MAP ๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
๊ณฐ๊ณฐํ ๋ค์ ์๊ฐํด๋ด
์๋ค.
์์ด๋ฌ๋ํ ๊ฒ์ ์ ํฌ๊ฐ ์ํ๋ ๊ฒ์ ์ฃผ์ด์ง ๋ฐ์ดํฐ์
์ด๋ฏธ์ง์ ์ ์ฌํ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ํ๋ฅ ์ด ๋์ model paramters theta ๋ฅผ ํ์ตํ๋ ๊ฒ์ด๋ผ๋ ์ ์
๋๋ค. ์ด๊ฒ ๋ฌด์จ ์๋ฆฌ์ธ๊ฐ ํ๋ฉด, ์ฌ์ค ๋ฅผ ์ต๋ํํ๋ theta ๋ฅผ ๊ตฌํ๋ฉด ์ ํฌ๋ ์ํ๋ ๋ชฉํ๋ฅผ ๋ฌ์ฑํ๋ค๋ ๊ฒ์
๋๋ค.
์ฅโฆ? ๊ทธ๋ฌ๋ฉด ์ง๊ธ๊น์ง MLE ๋ฉฐ, MAP ๋ฉฐ ํ๋ ์น๊ตฌ๋ค์ ์ ์๊ฐํ๊ฑฐ์ง?? ๋ผ๋ ์๋ฌธ์ด ๋ค ์ ์์ต๋๋ค. ๊ทธ๋ฅ ์ด๋ฏธ์ง X์ ๋ํด X๋ฅผ ๊ฐ์ฅ ์ ์์ฑํด๋ด๋ model paramters ๋ฅผ ์์ฑํด์ผ ํ๋ค๋ ์๋ก ๋ถํฐ ์์ํ์ผ๋ฉด ๋์ ๊ฒ ๊ฐ์๋ฐ ๋ง์
๋๋ค.
๊ฐ์ฅ ํฐ ์ด์ ๋ model paramters ์ ์ด๋ฏธ์ง์์ ๊ด๊ณ๋ฅผ ์ง์ ์ ์ผ๋ก ํ์ตํ๊ธฐ์ ๋๋ฌด ๋ณต์กํ๊ธฐ ๋๋ฌธ์
๋๋ค. ๊ทธ๋ ๊ธฐ์ ๊ทธ ์ค๊ฐ ๋จ๊ณ์ธ latent variables z ๋ผ๋ ์น๊ตฌ๋ฅผ ๋ค์ฌ๊ฐ๋ฉฐ ์ด๋ฏธ์ง์ ํต์ฌ ์์๋ง์ ๊ตฌ์ฑํ๋ vector ๋ฅผ ๋์ด ์๊ฐํ ๊ฒ์
๋๋ค. Latent variables ์ ์ด๋ฏธ์ง ๊ฐ์ ํ์ตํ๊ธฐ ์ฉ์ดํ๊ธฐ ๋๋ฌธ์ธ ๊ฒ์
๋๋ค. ๊ทธ๋์ ๋ฅผ ๊ณ ์ ํด ๋ ๋ค ์ด๋ฏธ์ง๋ฅผ ๊ฐ์ฅ ์ ์์ฑํ ์ ์๋ Latent variables ๋ฅผ ๋ฝ์๋ด๊ณ ์ด๋ค์ ๋ ์ ๋ํด์ ๋น๊ตํด๊ฐ๋ฉด์ ํ๋ ์ฐพ์๋ด๋ ์๋๋ฅผ ์๊ฐํด๋ณผ ์ ์์๋ ๊ฒ์
๋๋ค.
์๋๊ฐ ๊ธธ์๋๋ฐ, ๊ฒฐ๊ตญ ์ ํฌ๊ฐ ์ํ๋ ๊ฒ์ ๋ฅผ ์ต๋ํํ๋ theta ์ด๊ณ , MAP ์ ๋ํ ์๊ฐ์ ์ง๊ธ๋ถํฐ ๋ฒ๋ฆฌ์
๋ ์ข์ต๋๋ค. ๋จ์ง ๋ฐฐ๊ฒฝ ์ค๋ช
์ ์ํด ๋์
ํ ์น๊ตฌ์์ ๋ฟ์
๋๋ค. ์ง๊ธ๋ถํฐ ์ค๋ช
๋๋ฆด ๋ด์ฉ์ด ์ด ๋
ผ๋ฌธ์ ์ฃผ์ ๋ด์ฉ์
๋๋ค.
Variational Inference
์ ํต์ ์ผ๋ก intractable ํ p_theta(x) ๋ถ๋ถ์ ๊ตฌํด๋ด๊ธฐ ์ํดย Variational Inferenceย ๋ฅผ ์ฌ์ฉํ์ต๋๋ค. Variational Inference ๋ฅผ ํ ๋ง๋๋ก ํํํ์๋ฉด ๊ตฌํ๊ธฐ ์ด๋ ค์ด ์ฌํ ํ๋ฅ ๋ถํฌ๋ฅผ ์ ํฌ๊ฐ ์๊ณ ์๋ ํ๋ฅ ๋ถํฌ๋ก ๊ทผ์ฌํ๋ ๊ฒ์
๋๋ค.
์ด ๋, ์๊ณ ์๋ ํ๋ฅ ๋ถํฌ๋ ํ๋ฅ ๋ถํฌ๋ฅผ ์ ์ํ๊ณ ์๋ ๋ณ์๋ค์ ์ํด ๊ฒฐ์ ๋๊ณ , ๊ตฌํ๋ ค๊ณ ํ๋ ์ฌํ ํ๋ฅ ๋ถํฌ๋ ์ด๋ ํ ์ฌ์ ์ฌ๊ฑด์ ์ํฅ์ ๋ฐ์์ ๊ฒฐ์ ๋ฉ๋๋ค. ์๊ณ ์๋ ํ๋ฅ ๋ถํฌ๋ฅผ g, ํ๊ฒ ์ฌํ ํ๋ฅ ๋ถํฌ๋ฅผ f ๋ผ๊ณ ํ๋ฉด, g(X|phi) ์ f(X|theta) ๋ฑ์ผ๋ก ํํํ ์ ์๋ ๊ฒ์
๋๋ค.
์ด ๋, Variational Inference ์์๋ ๋ ํ๋ฅ ๋ถํฌ์ ์ฐจ์ด๋ฅผ ์ค์ด๋ ๋ฐฉํฅ์ผ๋ก phi ๋ฅผ ์กฐ์ ํด ๋๊ฐ๋๋ค. ๊ทธ๋ฆฌ๊ณ ๋ ํ๋ฅ ๋ถํฌ์ ์ฐจ์ด๋ KL Divergence(Kullback Leibler Divergence) ๋ฅผ ํตํด ๊ตฌํ๊ฒ ๋ฉ๋๋ค.
๋์น ๋น ๋ฅด์ ๋ถ๋ค์ P~Q ์ผ ๊ฒฝ์ฐ์ KL Divergence ๊ฐ 0์ ๊ฐ๊น์์ง๋ค๋ ๊ฒ์ ์์ค ์ ์์ต๋๋ค. ์ด๋ฌํ ์ฑ์ง์ ์ด์ฉํด KL Divergence ๋ฅผ ์ค์ด๋ ๋ฐฉํฅ์ผ๋ก phi ๋ฅผ ์ต์ ํํ์ฌ ๊ทผ์ฌ ํ๋ฅ ๋ถํฌ๋ฅผ ๊ตฌํด๋ธ๋ค๊ณ ๋ณด์๋ฉด ๋ฉ๋๋ค. ์ด๋ฌํ ๋ชฉ์ ์ฑ์ ์๋์ ๊ฐ์ด ์ ๋ฆฌํ ์ ์์ต๋๋ค.
Variational Inference ๋ statistical inference ๋ฌธ์ ์๋ย posterior estimation ์ optimization ๋ฌธ์ ๋ก ๋ณ๊ฒฝํ๋ค๋ ๊ฒ์ ์์๊ฐ ์์ต๋๋ค. ํ์ง๋ง ์ด๋ฌํ ๋ฐฉ๋ฒ์ phi ์ ๋ํ update ์ย ์ถ๊ฐ์ ์ผ๋ก iteration์ ์จ์ผ ํ๋ค๋ ๋จ์ ์ด ์กด์ฌํ์ต๋๋ค.
VAE
VAE ์์๋ ์์ Variational Inference ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ์ฌ ์๋ก์ด ์ ์์ ํฉ๋๋ค. Optimization ์ผ๋ก ๋ณํํ ๋ฌธ์ ๋ฅผ iterative ํ๊ฒ update ํ์ง ์๊ณ ์์ g ๋ก ํ๊ธฐํ๋ posterior q(z|x) ์ likelihood p(x|z) ๊ฐ ๊ฐ๊ฐ encoder ์ decoder ์ธย auto encoder ๋ฅผ ๋ชจ๋ธ๋งย ํ ๊ฒ์
๋๋ค.
์ด๊ฒ ๋ฌด์จ ์๋ฆฐ๊ฐ ํ๋ฉดโฆ input ์ผ๋ก ๋ค์ด์จ ์ด๋ฏธ์ง๋ก๋ถํฐ latent variables ๋ฅผ ์์ฑํ๋ encoder ๋ถ๋ถ๊ณผ ์์ฑํ latent variables ๋ก๋ถํฐ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ decoder ๋ถ๋ถ์ ๋๋์ด์ ํ ๋ฒ์ update ๋ก posterior ์ likelihood ๋ฅผ ๋ชจ๋ update ํ ์ ์๋๋ก ์ค๊ณํ ๊ฒ์
๋๋ค.
โ๊ทธ๋ ์ข์, ๊ธฐ์กด์ ์ธ๊ธํ๋ update ๋ฅผ ์ํ ์ถ๊ฐ์ ์ธ iteration ์ ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ๋ณ๊ฒฝํจ์ผ๋ก์จ ํด๊ฒฐํ๊ตฌ๋. ๊ทธ๋ฐ๋ฐ ํ์ต์ ์ํ ๊ธฐ์ค์ ์ด๋ป๊ฒ ์ธ์ฐ์ง?โ ๋ผ๋ ์๋ฌธ์ด ๋์ ๋ค๋ฉด ์ ์ดํดํ์๊ณ ๊ณ์ ๊ฒ๋๋ค. ์ด์ ๊ทธ ๊ธฐ์ค์ ๋ํด์ ์ค๋ช
ํ๋ ค๊ณ ํฉ๋๋ค.
Variational Bound
์๊น ๋ง์๋๋ ธ๋ ์ด์ผ๊ธฐ๋ฅผ ๋ค์ ํ ๋ฒ ์ธ๊ธํ๊ฒ ์ต๋๋ค. ์ ํฌ๋ p_theta(x) ๋ฅผ ์ต๋ํํ๋ theta ๋ฅผ ๊ตฌํ๋ ค๊ณ ํ๋ ๊ฒ์
๋๋ค. ์ด๋ฅผ ์ํด์ ๊ณ์ฐ์ ํธ์์ฑ์ ์ํด log(p_thata(x))์ ๋ํด์ ์ ๊ทผ์ ์๋ํฉ๋๋ค. ์ด๋ฅผ Marginal (log) likelihood ๋ผ๊ณ ํฉ๋๋ค. ์ ํฌ๊ฐ ํ์ต์ ์ฌ์ฉํ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์
์ด x(1), x(2), โฆ x(n) ์ด๋ผ ๊ฐ์ ํ ๋ ์๋์ ๊ฐ์ด ์๋์ ๊ฐ์ด theta ์ ๋ํด ์ฃผ์ด์ง ๋ฐ์ดํฐ์
๊ณผ ์ ์ฌํ ์ด๋ฏธ์ง๋ฅผ ์์ฑํด๋ผ ํ๋ฅ p_theta ๋ฅผ ์๋ฏธํ๋ term ์ ์์ฑํด๋ผ ์ ์๋ ๊ฒ์
๋๋ค.
์ฌ๊ธฐ์ KL Divergence ์ ๋ํ ์ ๊ฐ๋ฅผ ํตํด์ ์ญ์ผ๋ก log(p_theta(x)) ์ ๋ํ lower bound ๋ฅผ ๊ตฌํด๋ผ ์ ์์ต๋๋ค. ์ด๋ KL Divergence ๊ฐ 0 ์ด์์ด๋ผ๋ ์ฆ๋ช
ํ์ ์ด๋ฃจ์ด์ง ์ ๊ฐ๊ณผ์ ์
๋๋ค.(์ฆ๋ช
์ ์ ์ฌ์ง์ ์ต ํ๋จ์ ์กด์ฌํฉ๋๋ค. ์ฆ๋ช
์ฐ์ธก์ X>0 ์กฐ๊ฑด์ด ๋น ์ก๋ค์..ใ
) ์ ์ฌ์ง์ ์๋์์ ์ธ ๋ฒ์งธ ์ค ํญ๋ชฉ์ ELBO(Evidence of Lower BOund) ๋ผ๊ณ ๋ถ๋ฆ
๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ํด์ ๊ตฌํด๋ธ lower bound ๋ฅผ ๋ณํํ๋ฉด ์์ ๊ฐ์ ์์ผ๋ก ๋ณํํฉ๋๋ค. ๋จ์ํ ์ํ์ ๊ณ์ฐ์ผ๋ก ๋ณํํด ๋ธ ๊ฒ์ด ์ด๋ค ์๋ฏธ๊ฐ ์์๊น ์ถ์ง๋ง ์์์ ์๋ฏธํ๋ ๋ฐ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค.
์๋์์ ์ธ ๋ฒ์งธ ์ค์ ์์ RHS ์ ๋ ๋ฒ์งธ ํญ์ Reconstruction Error ๋ก ๋ณผ ์ ์์ต๋๋ค. Input image x(i) ๋ก๋ถํฐ ์์ฑ๋ Latent variables ์ ํ๋ฅ ๋ถํฌ์์ ์ผ๋ง๋งํผ์ ๋ค์ input image x(i) ๋ฅผ decoder ๋ก reconstruct ํ ์ ์๊ฒ๋ model parameters ๊ฐ ์ ์ค์ ๋์ด ์๋๋๋ฅผ ๋ณด๋ ํญ๋ชฉ์
๋๋ค. ์ฌ๊ธฐ์ ์์ KL Divergence term ์ ์ถ๊ฐํ์ฌ ์๋ ์ฒ์์ ์๋ํ๋ q_phi ๋ผ๋ inference ๋ฅผ ์ํ ํ๋ฅ ๋ถํฌ๊ฐ p_theta ์ ๋น์ทํ๊ฒ๋ ๋ง๋๋ regularization ๋ ์งํํฉ๋๋ค.
์ข
ํฉํ์ฌ ์ lower bound ๋ฅผ maximize ํ๋ ๊ฒ์ด ์ ํฌ์ ์ต์ข
๋ชฉํ๋ก ๋ณ๊ฒฝ๋์์ต๋๋ค.
Reparametrization trick
์ต์ข
์ ์ผ๋ก model parameter update ์ ๋ํ ๊ธฐ์ค์ ์ธ์ ์ต๋๋ค. ํ์ง๋ง, ๋ฌธ์ ์ ์ด ์์ต๋๋ค.
์์์ ์ป์ด๋ธ ์ ํฌ๊ฐ ์ง์คํด์ผ ํ lower bound function ์ ๋ณด์๋ฉด ๊ธฐ๋๊ฐ์ ๊ตฌํ๊ธฐ ์ํด์ q_phi ๋ถํฌ์์ sampling ํ๋ ๊ณผ์ ์ด ํ์ํฉ๋๋ค. Forward propagation ์์ ์ด๋ฌํ sampling ์ ์๊ฐ์ ์ค๋ ํ์๋ก ํ ๋ฟ๋ง ์๋๋ผ backward propagation ์์๋ ์์ sampling ์์ฒด๊ฐ ๋ฏธ๋ถ ๊ฐ๋ฅํ ์ฐ์ฐ์ด ์๋๊ธฐ ๋๋ฌธ์ ์ด๋ฐ ํํ๋ก ์ค๊ณํ๋ค๋ฉด ์ฌ๋ฐ๋ฅด๊ฒ ํ์ต์ ์ค๊ณํ ์ ์์ต๋๋ค.
์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํ ๊ณผ์ ์ดย Reparametrization Trickย ์
๋๋ค. ๋ง ๊ทธ๋๋ก ๋ณ์๋ฅผ ๋ค๋ฅธ ํํ๋ก ์นํํ์ฌ ๊ผผ์๋ฅผ ๋ถ๋ฆฌ๋ ๊ณผ์ ์
๋๋ค.
๊ฒฐ๋ก ๋ถํฐ ๋ง์๋๋ฆฌ์๋ฉด q_phi ํ๋ฅ ๋ถํฌ๋ฅผ ๋ฐ๋ฅด๋ ๋ณ์๋ฅผ sampling ํ๋ ๊ฒ์ด ์๋, q_phi ํ๋ฅ ๋ถํฌ๋ฅผ ๊ตฌ์ฑํ๋ ์์๋ค๊ณผ p(epsilon) ์ ๋ฐ๋ฅด๋ ์๋ก์ด vector epsilon ์ ์ด์ฉํด์ ์๋กญ๊ฒ ํ๋ฅ ๋ณ์๋ฅผ ๊ณ์ฐํ๊ฒ ๋ฉ๋๋ค. ์ด๋ ๊ฒ ๊ณ์ฐํ๋ ๊ณผ์ ์ ํตํด์ ํ๋ฅ ๋ณ์์ ๋ฏธ๋ถ ๊ฐ๋ฅ์ฑ์ ๋ถ์ฌํ ์ ์๊ฒ ๋๋ฉด์๋, ํ๋ฅ ๋ถํฌ์์๋ถํฐ sampling ํ๋ ํจ๊ณผ๋ฅผ ๋ณผ ์ ์๋ ๊ฒ์
๋๋ค.
๊ฐ์ฅ ๊ฐ๋จํ ์์๋ก q_phi ๊ฐ gaussian ์ธ ๊ฒฝ์ฐ mean mu์ variance sigma ๋ฅผ ์ด์ฉํด์ mu + sigma*epsilon ๊ณผ ๊ฐ์ด ํ๋ฅ ๋ณ์๋ฅผ ์ฌ์ฐฝ์กฐํ๋ ๊ณผ์ ์ธ ๊ฒ์
๋๋ค.
์ต์ข
์ ์ผ๋ก๋ ์ ๊ทธ๋ฆผ์์ ๋ณด์ด๋ ๋ฐ์ ๊ฐ์ด Monte Carlo Expectation ์ ํตํด ์ผ๋ฐ์ ์ผ๋ก ํจ์ f(z) ๋ฅผ reparametrization ์ ํตํด ๋ณํํ๋ ๊ณผ์ ์ ๊ธฐ์กด ์์๋ค๊ฐ ์ ์ฉํด์ lower bound function ์ ๋ณํํ ์ ์์ต๋๋ค. ์ ์์ L^A, L^B ์ ๊ฒฝ์ฐ์๋ lower bound ์์ด ์์์ ๋ ๊ฐ๋ฅผ ๋ง์ ๋๋ ธ๋๋ฐ ๊ฐ๊ฐ์ ๋ํด์ ์ฌ์ฉํ ๊ฒ์ด๋ฉฐ, ์ต์ข
์ ์ผ๋ก๋ ์ด๋ ๊ฒ์ ์ฌ์ฉํด๋ ๋ฌด๋ฐฉํฉ๋๋ค. ๋ํ minibatch ๋ก ํ์ตํ๊ฒ ๋ ๊ฒฝ์ฐ์๋ ์ ์ฒด ๋ฐ์ดํฐ์
์ฌ์ด์ฆ์ ๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ ๊ทธ๋ฆผ์ฒ๋ผ lower bound ๋ฅผ ์์ ํด์ค๋๋ค.
Choosing Reparametrization Function
๋์ผ๋ก ๋
ผ๋ฌธ์์๋ reparametrization function ์ ์ด๋ป๊ฒ ํ๋ฉด ์ ์ค์ ํ ์ ์๋๊ฐ์ ๋ํ ํ์ ์๋ ค์ฃผ๊ณ ์์ต๋๋ค. ๊ฐ๋จํ๊ฒ๋ง ์๊ฐ ๋๋ฆฌ์๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
1.
Inverse CDF ๋ก ์ ์ํ ๋ค, epsilon ์ uniform(0,1) ํ๋ฅ ๋ถํฌ์์ sampling ํ์ฌ ๊ณ์ฐํด๋ธ๋ค.
2.
Gaussian ๊ณผ ์ ์ฌํ ๊ฒฝ์ฐ, location-scale ์ ์ค์ ํ์ฌ location + scale*epsilon ์ผ๋ก ์ ํ๊ณ , epsilon ์ N(0,1) ์์ sampling ํ์ฌ ๊ณ์ฐํด๋ธ๋ค.
3.
๋ค๋ฅธ ๋ณด์กฐ ๋ณ์๋ค์ ๋ณํ์ผ๋ก ๊ณ์ฐํด๋ธ๋ค.
๋ฑ์ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ๋ค๊ณ ํฉ๋๋ค. ์ด ๋ถ๋ถ์ ์ฐธ๊ณ ๋ง ํ์
๋ ์ข์ ๊ฒ ๊ฐ์ต๋๋ค.
Experiments
๋
ผ๋ฌธ์์ ์งํํ ์คํ์ผ๋ก ๋
ผ๋ฌธ์ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ ํ์ต๊ณผ ๊ธฐ์กด์ Wake-Sleep ๋ ๋จ๊ณ๋ก ์ด๋ฃจ์ด์ก๋ ํ์ต๊ณผ์ ๋น๊ต๋ฅผ ์งํํ์ต๋๋ค.
์ ๊ทธ๋ฆผ์์ ๋ณด์ด๋ ๊ฒ ์ฒ๋ผ latent space dimension์ด ๋ฐ๋๊ณผ ๊ด๋ จ์์ด ์ ์ฒด์ ์ผ๋ก marginal (log) likelihood ๊ฐ ๋
ผ๋ฌธ์์ ๋ณด์ฌ์ค ์๊ณ ๋ฆฌ์ฆ์ ํ์ฉํ์ ๋ ๋ ๋๊ฒ ๋ํ๋จ์ ๋ณด์ฌ์ฃผ์์ต๋๋ค.
๋ง์ฐฌ๊ฐ์ง๋ก Monte Carlo Estimation ๊ณผ ํจ๊ป training dataset size ์ ๋ฐ๋ฅธ ๋น๊ต๋ ์งํํ์ต๋๋ค. Monte Carlo Estimation ์ ๊ฒฝ์ฐ online algorithm ์ด ์๋๊ธฐ ๋๋ฌธ์ ์ ์ฒด ๋ฐ์ดํฐ์
์ ๊ฐ์ง๊ณ ์์ํ์ง ์์ผ๋ ํจ๊ณผ๊ฐ ๋จ์ด์ง๋ ๊ฒ์ ๋ณด์๊ณ , Wake-Sleep ๊ณผ ๋
ผ๋ฌธ์ ์๊ณ ๋ฆฌ์ฆ ์ฌ์ด์์ ๋
ผ๋ฌธ์ ์๊ณ ๋ฆฌ์ฆ์ ๋์ marginal (log) likelihood ๋ฅผ ๋ณด์์์ ๋ด์ธ์ ์ต๋๋ค.
Conclusion
์ด๊ฒ์ผ๋ก ๋
ผ๋ฌธโAuto-Encoding Variational Bayesโ์ ๋ด์ฉ์ ๊ฐ๋จํ๊ฒ ์์ฝํด๋ณด์์ต๋๋ค.
์ค๋๋ง์ ์ด๋ก ์์ฃผ์ ๋
ผ๋ฌธ์ ์ฝ์ด์ ์ ์ ํ๋๋ฐ, ์๊ฐ๋ณด๋ค ๋ฐฐ๊ฒฝ์ง์์ด ๋ง์ด ํ์ํด์ ์ ์ ๋
ผ๋ฌธ์ ๋ค์ฌ๋ค๋ณด๋ ์๊ฐ๋ณด๋ค ๋ฐฐ๊ฒฝ์ง์์ ๊ณต๋ถํ๋ ์๊ฐ์ด ์กฐ๊ธ ๋ ๋ง์๋ ๊ฒ ๊ฐ์ต๋๋ค. ๊ทธ๋์์ธ์ง ๋
ผ๋ฌธ์ ์ฝ๊ณ ์ ์ฒด์ ์ธ ํต๊ณํ์ ๋ฐฐ๊ฒฝ ์ง์์ ์์ด์ ์ฑ์ฅํ ์ ์์๋ค๋ ๋๋์ด ๋ง์ด ๋ค์์ต๋๋ค.
๊ฐ์ธ์ ์ผ๋ก๋ ์์์ผ ํ๋ ์ง์์ ์ฐ๋๋ฏธ์ ๋ฌปํ๋ ๋๋์ด ๋ค์ด ์ฝ๊ธฐ์ ์กฐ๊ธ ๋ถ๋ด์ด ๋์์ง๋ง ์์๊ฐ๋๊ฒ ๋ง์๋ ๋
ผ๋ฌธ์ด์๋ ๊ฒ ๊ฐ์ต๋๋ค . ์ฌ๋ฌ๋ถ๋ค๋ ์์ฑ ๋ชจ๋ธ์ ๊ด์ฌ์ด ์์ผ์๋ค๋ฉด, ๊ทธ๋ฆฌ๊ณ ์ ๋๋ก ๊ทธ ๋ฐ์ ์ ์๊ณ ์ถ๋ค๋ฉด ๊ผญ ํ ๋ฒ ์ฝ์ด๋ณด์๋ฉด ์ข์ ๊ฒ ๊ฐ์ต๋๋ค.