Variational linear methods for approximating posteriors
| 1 min read
An explanation and a worked example will follow.
import torch
nn = torch.nn

class LinearVariationalNormal(nn.Module):
    """Linear layer with a Normal (mean-field) variational posterior over its weights and bias."""

    def __init__(self, in_features, out_features, bias=True,
                 make_weights_prior=None, make_bias_prior=None, **kwargs):
        super().__init__()
        # Variational parameters of the weight posterior q(W).
        w_loc = torch.zeros((in_features, out_features))
        w_scale = torch.ones((in_features, out_features))
        self.w_loc = nn.parameter.Parameter(w_loc, requires_grad=True)
        self.w_scale = nn.parameter.Parameter(w_scale, requires_grad=True)
        if make_weights_prior is None:
            # Default prior: standard Normal over every weight.
            make_weights_prior = lambda in_, out_: torch.distributions.Normal(
                loc=torch.zeros((in_, out_)), scale=torch.ones((in_, out_)))
        self.w_prior = make_weights_prior(in_features, out_features)
        if bias:
            # Variational parameters of the bias posterior q(b).
            b_loc = torch.zeros((out_features,))
            b_scale = torch.ones((out_features,))
            self.b_loc = nn.parameter.Parameter(b_loc, requires_grad=True)
            self.b_scale = nn.parameter.Parameter(b_scale, requires_grad=True)
            if make_bias_prior is None:
                # Default prior: standard Normal over every bias entry.
                make_bias_prior = lambda out_: torch.distributions.Normal(
                    loc=torch.zeros((out_,)), scale=torch.ones((out_,)))
            self.b_prior = make_bias_prior(out_features)

    def forward(self, x, kl_weight=1., **forward_kws):
        # Normal weights posterior; softplus keeps the scale strictly positive
        # while the underlying parameter is optimised unconstrained.
        q_w = torch.distributions.Normal(loc=self.w_loc,
                                         scale=nn.functional.softplus(self.w_scale))
        weights = q_w.rsample()  # reparameterised sample, so gradients reach loc and scale
        kl_loss = self.kl_loss(q_w, self.w_prior, kl_weight)
        if hasattr(self, 'b_loc'):
            # Normal bias posterior
            q_b = torch.distributions.Normal(loc=self.b_loc,
                                             scale=nn.functional.softplus(self.b_scale))
            bias = q_b.rsample()
            kl_loss += self.kl_loss(q_b, self.b_prior, kl_weight)
            return x @ weights + bias, kl_loss
        return x @ weights, kl_loss

    def kl_loss(self, q, r, kl_weight=1.):
        # Reduce the elementwise KL(q || prior) to a single scalar penalty.
        return kl_weight * torch.distributions.kl.kl_divergence(q, r).sum()
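Until the full explanation is posted, here is a minimal usage sketch of how such a layer is typically trained. The layer sizes, the random toy data, and the MSE data term are my own illustrative assumptions and do not come from the original code; the idea is simply that the returned KL penalty gets added to the data-fit loss, giving a negative-ELBO-style objective.

# Minimal usage sketch -- hypothetical sizes and random data, purely illustrative.
layer = LinearVariationalNormal(in_features=3, out_features=1)
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-2)

x = torch.randn(32, 3)   # toy inputs
y = torch.randn(32, 1)   # toy targets

for step in range(100):
    optimizer.zero_grad()
    # forward() returns a stochastic prediction plus the (weighted) KL penalty
    pred, kl = layer(x, kl_weight=1. / x.shape[0])
    loss = nn.functional.mse_loss(pred, y) + kl   # negative-ELBO-style objective
    loss.backward()
    optimizer.step()

Scaling the KL term by the batch size (via kl_weight) is one common convention; any schedule that anneals the KL weight during training would slot in the same way.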