PyTorch 中的 Unpooling 操作
Average Unpooling
Pytorch 中并没有直接实现 AverageUnpooling 的 layer,但是 pool 操作本身没有参数,因此可以认为是完全针对 Function 的再封装。通过 F.interpolate 操作可以实现类似 AverageUnpooling 的操作。参考 issue
其中,F.interpolate 函数的定义为:
1 | torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None) |
插值方式参数包括 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area' 六种。对应于AvgPool 方式还原的参数应该使用 area 算法进行插值,实际使用效果为:
1 | data = torch.randn([1,1,8,9]) |
首先我们使用 kernel=2*3 的 AvgPool 操作得到池化之后的矩阵
1 | unpool = torch.nn.functional.interpolate(pooled, size=(8,9)) |
然后我们使用 interpolate 函数进行了还原,每一个 2*3 大小的块都被填充上相同的值,填充的方式和池化被计算的方式是一样的。至此,我们便完成了池化操作的还原。
在某些情况下,池化操作可能还涉及到 padding,或者 kernel size 无法整除的情况,这时就需要按照上述算法还原之后通过裁剪等操作进行一些后处理,才能完全还原池化前的矩阵形状。
Max Unpooling
Pytorch 官方 doc 中给出了这个 layer 的实现。[see](https://pytorch.org/docs/stable/nn.html?highlight=max unpool#torch.nn.MaxUnpool2d)
1 | torch.nn.MaxUnpool2d(kernel_size, stride=None, padding=0) |
因为数据损失的问题,MaxPool 操作也无法完全还原,只能保证最大值点被正确填充,而且在进行 pool 操作的时候要求要保留 indices 的数据;其余点会使用零值填充,可以参考官方示例:
1 | pool = nn.MaxPool2d(2, stride=2, return_indices=True) |
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Flymin's Blog!
