9.4 Correct Gradient Estimators with DiCE

이번 section에서는 이를 모두 해결할 Infinitely Differentiable Monte-Carlo Estimator(DiCE)를 소개합니다. 이는 SCG에서 어떤 차수의 미분도 정확하게 계산할 수 있는 실용적인 알고리즘입니다. 특정 차수의 미분을 하기 위해 가장 간단한 방법은 9.3.1에 나온 방법을 재귀적으로 계속 사용하면 되겠지만, 이는 두 가지 결점을 가지고 있습니다. 첫번째로 gradient를 이렇게 정의하는 것이 auto-diff library에 적용하기 힘들다는 점입니다. 둘째로, 단순하게 gradient estimator를 구하면 θf(x;θ)g(x;θ) \nabla_{\theta}f(x;\theta) \neq g(x;\theta)이기 때문에 제대로 업데이트되지 않습니다.

시작전에 앞에서 정의한 것과 같이 L=E[cCc] \mathcal{L} = \mathbb{E}[\sum_{c\in\mathcal{C}}c]를 SCG에서의 objective로 정의하고 시작합니다. 이때 모든 의존성을 만족하는 gradient estimator는 다음과 같이 표현할 수 있습니다.

θL=E[cC(cwWcθlogp(wDEPSw)+θc(DEPSc))](1)\nabla_\theta\mathcal{L} = \mathbb{E}[\sum_{c\in\mathcal{C}}(c\sum_{w\in \mathcal{W}_c}\nabla_\theta\log p(w|\mathrm{DEPS}_w)+ \nabla_\theta c(\mathrm{DEPS}_c))] \cdots (1)

Wc \mathcal{W}_c는 stochastic nodes에 속하고, cost nodes에 영향을 끼치면서 θ\theta에 영향을 받는 모든 node를 의미합니다.ancestors node에 잘 조건화되었다고 가정하고 이제부터 DEPS의 표기에 대해 생략해서 표기하겠습니다.

이전 부터 소개했지만, DiCE에서는 높은 차수의 미분을 정확하게 하기 위해 MagicBox \square라는 operator를 사용하고, input으로는 stochastic nodes W\mathcal{W}, 그리고 아래와 같은 두가지 성질을 가지고 있습니다.

  • (W)1\square (\mathcal{W}) \rightarrow 1

  • θ(W)=(W)wWθlog(p(w;θ))\nabla_{\theta}\square(\mathcal{W})=\square (\mathcal{W})\sum_{w \in \mathcal{W}}\nabla_{\theta}\log(p(w;\theta))

첫번째 성질의 \rightarrow는 평가한다(evaluates to)라는 의미로 모든 gradient의 같음을 의미하는 full equality(=)와는 대조적입니다. auto-diff에서는 이를 forward pass evaluation의 의미로 사용합니다.

두번째 성질은 \square를 사용해서 sample이 어디서 sampling됐는지 그 분포에 대한 의존성을 보입니다.(ww에 대한 확률 합 형태가 됩니다.) 그리고 미분하면 log likelihood trick을 이용해 log형태로 나타난 것입니다. 이는 이 성질을 만족하면 첫번째 성질은 쉽게 만족할 수 있습니다. (총 확률 합이 1이므로)

두번째 특성을 만족한다면, L=E[cCc] \mathcal{L} = \mathbb{E}[\sum_{c\in\mathcal{C}}c] 인 objective에 대해 다음같이 표현할 수 있습니다.

L=cC(Wc)c  ((Wc)1) \mathcal{L}_\square = \sum_{c\in\mathcal{C}}\square(\mathcal{W}_c)c \ \ (\because \square(\mathcal{W}_c) \rightarrow 1)

L\mathcal{L}_{\square}을 가지고 어떻게 정확하게 고차미분을 할 수 있는지에 대해 증명해보겠습니다.

Theroem1.      E[θnL]θnL,n{0,1,2,}\bm{\mathrm{Theroem 1.\ \ \ \ \ \ }} \mathbb{E}[\nabla^n_{\theta}\mathcal{L}_\square] \rightarrow \nabla^n_\theta\mathcal{L},\forall n \in \{0,1,2,\cdots \}

모든 cost nodes cCc \in \mathcal{C}에 대해 다음과 같이 정의하겠습니다.

     c0   =      cE[cn+1]=θE[cn]\ \ \ \ \ c^0 \ \ \ = \ \ \ \ \ \ c \\ \mathbb{E}[c^{n+1}] = \nabla_\theta\mathbb{E}[c^n]

cnc^n는 objective E[c]\mathbb{E}[c]의 n차 미분값입니다.

다음으로 cnc^n_{\square} cn(Wcn)c^n\square(\mathcal{W}_{c^n})인데, magicbox operator의 첫번째 특성으로 인해, Wcn \square \mathcal{W}_c^n은 1이 되어, cncnc^n_\square \rightarrow c^n임을 알았습니다. 이를 통해, cnc^n_\square또한 objective의 n번째 미분값이랑 같다는 의미가 됩니다. 그렇다면, 마지막으로 θcn=cn+1\nabla_{\theta}c^n_\square = c^{n+1}_\square임을 보이면 n차 미분 전체에 대해 magicbox operator로 구할 수 있고, 그것이 실제 미분값과 같다는 의미가 됩니다.

θcn=θ(cn(Wcn))\nabla_\theta c^n_\square = \nabla_\theta(c^n\square(\mathcal{W}_{c^n}))

=cnθ(Wcn)+θ(Wcn)cn= c^n\nabla_\theta\square(\mathcal{W}_{c^n})+ \nabla_\theta (\mathcal{W}_{c^n}) \square c^n

=cn(Wcn)(wWcnθlog(p(w;θ)))+(Wcn)θcn= c^n\square(\mathcal{W}_{c^n})(\sum_{w\in\mathcal{W}_{c^n}}\nabla_\theta\log(p(w;\theta)))+ \square(\mathcal{W}_{c^n}) \nabla_\theta c^n

=(Wcn)(θcn+cnwWcnθlog(p(w;θ)))(9.4.4)= \square(\mathcal{W}_{c^n}) (\nabla_\theta c^n+c^n\sum_{w\in\mathcal{W}_{c^n}}\nabla_\theta\log(p(w;\theta))) \cdots (9.4.4)

(Wcn+1)cn+1=cn+1(9.4.5) \square(\mathcal{W}_{c^{n+1}})c^{n+1} = c^{n+1}_\square \cdots (9.4.5)

이 때, (9.4.4)에서 (9.4.5)로갈 때, 두가지 테크닉이 필요합니다. 첫번째로, L=E[cn] \mathcal{L} = \mathbb{E}[c^n]의 형태를 본문 위(1)형태로 변환해 사용하는 것입니다. 그렇게 되면 다음과 같이 표현할 수 있습니다.

cn+1=θcn+cnwWcnθlogp(w;θ) c^{n+1} = \nabla_{\theta}c^n + c^n \sum_{w \in \mathcal{W}_{c^n}}\nabla_\theta \log p(w;\theta)

이를 자세히 보면 (9.4.4)의 표현과 같음을 알 수 있습니다. 둘 째로, Wcn\mathcal{W}_{c^n}Wcn+1\mathcal{W}_{c^{n+1}}은 같은 stochastic nodes를 가리키고있을 것이므로, Wcn=Wcn+1\mathcal{W}_{c^n} =\mathcal{W}_{c^{n+1}}이 자명합니다.

Last updated