来源:paper, code,CVPR 2019, 韩国高等科技研究所&蒙特利尔大学

Abstract

  • EGNN: edge-labeling graph neutral network
  • 学习预测边的label,使得网络可以利用对于类内相似性和类间不相似性的直接连接状态的探索,通过迭代更新边的标签,使得聚类表示是可以进化的
  • 适合在各种数量的类别之间迁移并且不需要重新训练

Introduction

  • meta-learning:few-shot learning, learn-to-learn, non-stationary reinforcement learning, continual learning
  • 使用GNN和深度神经网络来解决富连接结构的数据
    • GNN可以通过消息传递迭代地完成邻居的数据聚合,从而表述数据实体之间的复杂关系。
    • 小样本算法需要更加全面的发掘support和query之间的关系
  • 现有的方法
    • Few-shot learning with graph neural networks:先建立一个从support到query的全连接的图,节点使用嵌入向量和独热编码label表示,通过邻居聚合迭代的更新节点feature,完成对query的分类。
    • Transductive propagation network for few-shot learning(TPN):在通过深度神经网络得到的节点特征上进行传导传播,测试过程中迭代性的在整个support和query实例上传播独热编码label【?】
    • 以上方法都是隐式的建模类内的相似性和类间的不相似性
  • 边标注模型显式聚类的方法,结合表征学习和度量学习;并不需要提前指定类别数目。
  • 本文的模型由多层组成,每一层都包括一个节点更新block和一个边更新block。边的label通过最后一次的label feature得到。EGNN可以实现对所有query数据整体的推导式传播。
  • contribution
    • EGNN首次使用边标注的方式构建
    • 在少样本分类的有监督和半监督任务中,表现超过了所有的GNN。同时,证明显式聚类和分别利用类内相似、类间不同是有效的。

Edge-Labeling Graph

  • 相关性聚类分析(Correlation clustering)是一个通过同时最大化类间不同于类内相同来实现边标注推理的图分割算法。【听起来有些类似支持向量机?】
  • 图注意力网络中的注意力机制最近被扩展到加到实数边特征,既适合局部内容,也适合全局层次的引文网络建模(citation network,社交网络的一种)

Few-shot learning

  • 表征学习:同归最近邻表征之间的相似性完成预测
  • meta-learner:学习优化模型中的参数提取出一些可以在任务之间使用上下文转移的知识(Meta-LSTM, MAML, Reptile, SNAIL)

Method

Problem definition: Few-shot classification

  • 在每个类只有很少样本的情况下学习一个分类器
  • 每一个few-shot分类任务T\mathcal{T}包括support set S\mathcal{S}和query set Q\mathcal{Q}
  • episodic training:在training task中抽样,模拟少样本测试时的场景
    • T=SUS={(xi,yi)}i=1N×K and Q={(xi,yi)}i=N×K+1N×K+T\mathcal{T}=\mathcal{S}\cup\mathcal{U}\text{, }\mathcal{S}=\{(x_{i},y_{i})\}_{i=1}^{N\times{K}}\text{ and }\mathcal{Q}=\{(x_{i},y_{i})\}_{i=N\times{K}+1}^{N\times{K}+T}
    • xi,yi{C1,...,CN}=CTCx_{i},y_{i}\in\{C_{1},...,C{N}\}=\mathcal{C}_{\mathcal{T}}\subset\mathcal{C},同时CtrainCtest=\mathcal{C}_{train}\cap\mathcal{C}_{test}=\emptyset

Model

  • 前提:通过卷积神经网络提取所有样本的特征

  • 图构建:G=(V,E;T)\mathcal{G}=(\mathcal{V},\mathcal{E};\mathcal{T}) 每一个节点代表一个样本,全连接图,每一条表代表一种关系种类

    • ground truth: edge-label

      yij={1,if yi=yj,0,otherwise.y_{ij}= \begin{cases} 1, & \text{if }y_{i}=y_{j},\\ 0, & \text{otherwise.} \end{cases}

    • 边特征eij={eijd}d=12\mathbf{e}_{ij}=\{e_{ijd}\}_{d=1}^{2} 一个二维向量,表示所连接的两个节点的intra-和inter-关系【这是同时表征类间不同与类内相同的来源】

      eij0={[10],if yij=1 and i,jN×K,[01],if yij=0 and i,jN×K,[0.50.5],otherwise\mathbf{e}_{ij}^{0}= \begin{cases} [1||0],& \text{if }y_{ij}=1\text{ and }i,j\leq{N\times{K}},\\ [0||1],& \text{if }y_{ij}=0\text{ and }i,j\leq{N\times{K}},\\ [0.5||0.5],& \text{otherwise} \end{cases}

    • 节点特征vi0=femb(ei;θemb)\mathbf{v}_{i}^{0}=f_{emb}(e_{i};\theta_{emb}),其中θemb\theta_{emb}表示参数集合
  • 传播,从1\ell-1层得到的特征vi1 and eij1\mathbf{v}_{i}^{\ell-1}\text{ and }\mathbf{e}_{ij}^{\ell-1}

    • 节点特征(\ell层):首先接受“邻居聚合”处理\to然后进行特征转换

    • 边特征(1\ell-1层)

      • 首先被用作度量邻居节点关系,与注意力机制相似

        vi=fv([je~ij11vj1e~ij21vj1];θv)\mathbf{v}_{i}^{\ell}=f_{v}^{\ell}\left(\bigg[\sum_{j}\tilde{e}_{ij1}^{\ell-1}\mathbf{v}_{j}^{\ell-1}||\tilde{e}_{ij2}^{\ell-1}\mathbf{v}_{j}^{\ell-1}\bigg];\theta_{v}^{\ell}\right)

        其中,e~ijd=eijdkeikd\tilde{e}_{ijd}=\frac{e_{ijd}}{\sum_{k}e_{ikd}}fv(θ)f_{v}^{\ell}(\theta)就是转移函数【从描述来看本文将相同/不同关系作为了一种通道处理,统一建模】

      • 然后利用更新的节点表征更新边的表征【相同/不同的表征更新方式一致】

        其中,fe(θ)f_{e}^{\ell}(\theta)表示相似度的计算函数

      • 从更新过程可以看出,边的更新不光考虑到所连接的节点的影响,也考虑了其他节点对的影响

  • 输出结果:节点属于某个集合的分布P(yi=CkT)=pi(k)P(y_{i}=\mathcal{C}_{k}|\mathcal{T})=p_{i}^{(k)}

    pi(k)=softmax({j:ji(xj,yj)S}y^ijδ(yj=Ck))(7)p_{i}^{(k)}=\text{softmax}(\sum_{\{j:j\neq{i}\wedge(x_{j},y_{j})\in\mathcal{S}\}}\hat{y}_{ij}\delta(y_{j}=\mathcal{C}_{k}))\tag{7}

    或者也可以使用图聚类进行分类,但本文中使用Eq.(7)。

  • 算法

Training - loss function

L==1Lm=1MλLe(Ym,e,Y^m,e)\mathcal{L}=\sum_{\ell=1}^{L}\sum_{m=1}^{M}\lambda_{\ell}\mathcal{L}_{e}(Y_{m,e},\hat{Y}_{m,e}^{\ell})

其中,mm是task数,\ell是层数,Le\mathcal{L}_{e}定义为二值的交叉熵loss

Experiments

本文使用miniImageNet、tieredImageNet数据集做比较

Evaluation

5-way 5-shot,测试时从5个类中随机抽取15 quries,取600个从test集随机生成的episodes的平均值

结果

从结果中看出增长效果啊哈似乎比较明显的,其中Trans为:

  • “No” means nontransductive method, where each query sample is predicted independently from other queries
  • “Yes” means transductive method where all queries are simultaneously processed and predicted together【所以这个transductive到底表示什么? 产生相互作用的方式?】
  • “BN” means that query batch statistics are used instead of global batch normalization parameters, which can be considered as a kind of transductive inference at test-time

Ablation studies

  • 增加layer可以使得结果有提高
  • 使用inter&intra是对结果有帮助的
  • EGNN有更好的可扩展性,尤其是当训练设置与测试设置不一致时也能给出很好地结果
  • t-SNE结果【这是个啥?】,可以看到相比GNN,EGNN可以很好地将query分开