Variational linear methods for approximating posteriors

| 1 min read

A fuller explanation and worked example will follow. For now, the sketch below implements a linear layer with a learned Normal (mean-field) posterior over its weights and bias, built on PyTorch's torch.distributions.
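
For orientation, the training objective is the standard evidence lower bound (ELBO): the layer learns a factorized Normal posterior q(w) over its weights and is trained to maximize

$$\mathcal{L} = \mathbb{E}_{q(w)}\big[\log p(\mathcal{D} \mid w)\big] - \mathrm{KL}\big(q(w)\,\|\,p(w)\big),$$

which is why forward() returns a KL term for the training loop to add to the data-fit loss.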


# Variational linear layer: mean-field Normal posterior over weights and bias
import torch
from torch import nn

class LinearVariationalNormal(nn.Module):
    def __init__(self, in_features, out_features, bias=True,
                 make_weights_prior=None, make_bias_prior=None, **kwargs):
        super().__init__()
        # Variational parameters of the Normal posterior over the weights.
        # The scale is stored as a log so it stays strictly positive while
        # training (exp(0) = 1 gives an initial posterior scale of 1).
        w_loc = torch.zeros((in_features, out_features))
        w_log_scale = torch.zeros((in_features, out_features))
        self.w_loc = nn.Parameter(w_loc, requires_grad=True)
        self.w_log_scale = nn.Parameter(w_log_scale, requires_grad=True)
        if make_weights_prior is None:
            # Default prior: independent standard Normal on every weight
            make_weights_prior = lambda in_, out_: torch.distributions.Normal(
                loc=torch.zeros((in_, out_)), scale=torch.ones((in_, out_)))
        self.w_prior = make_weights_prior(in_features, out_features)
        if bias:
            b_loc = torch.zeros(out_features)
            b_log_scale = torch.zeros(out_features)
            self.b_loc = nn.Parameter(b_loc, requires_grad=True)
            self.b_log_scale = nn.Parameter(b_log_scale, requires_grad=True)
            if make_bias_prior is None:
                # Default prior: independent standard Normal on every bias entry
                make_bias_prior = lambda out_: torch.distributions.Normal(
                    loc=torch.zeros(out_), scale=torch.ones(out_))
            self.b_prior = make_bias_prior(out_features)

    def forward(self, x, kl_weight=1., **forward_kws):
        # Normal weight posterior; rsample() is the reparameterized sample,
        # so gradients flow back to the variational parameters.
        q_w = torch.distributions.Normal(loc=self.w_loc, scale=self.w_log_scale.exp())
        weights = q_w.rsample()
        kl_loss = self.kl_loss(q_w, self.w_prior, kl_weight)
        if hasattr(self, "b_loc"):
            # Normal bias posterior
            q_b = torch.distributions.Normal(loc=self.b_loc, scale=self.b_log_scale.exp())
            bias = q_b.rsample()
            kl_loss += self.kl_loss(q_b, self.b_prior, kl_weight)
            return x @ weights + bias, kl_loss
        return x @ weights, kl_loss

    def kl_loss(self, q, r, kl_weight=1.):
        # kl_divergence is elementwise; sum to a scalar before weighting
        return kl_weight * torch.distributions.kl.kl_divergence(q, r).sum()
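
A minimal usage sketch for a regression setup; the data tensors, learning rate, and KL weighting below are illustrative choices, not part of the layer itself:

layer = LinearVariationalNormal(in_features=3, out_features=1)
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-2)

x = torch.randn(32, 3)   # toy inputs
y = torch.randn(32, 1)   # toy targets

for _ in range(100):
    optimizer.zero_grad()
    # kl_weight down-weights the KL per example (a common minibatch choice)
    pred, kl = layer(x, kl_weight=1. / x.shape[0])
    loss = nn.functional.mse_loss(pred, y) + kl  # data-fit term plus KL penalty
    loss.backward()
    optimizer.step()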