This paper aims to learn a sparse semantic representation of texts and images, in contrast to the dense representations learned by CLIP or ALIGN. The sparse embedding is obtained in three steps:

(1) The input (image or text) is encoded by a Transformer into features $h$, where $h_{j}$ corresponds to the $j$-th word in the input.

(2) Each $h_{j}$ is mapped to $p(h_{j})$ in the vocabulary space $V$ by a mapping function (in this paper, the BERT masked language model, MLM). Each $p(h_{j})$ therefore assigns a weight to every token in the vocabulary $V$.

(3) Max pooling over the positions $j$ is applied to $p(h_{j})$, giving a single value for each vocabulary token. The result is a sparse vector in a $V$-dimensional space.

https://i.imgur.com/BTvndLR.png

Training: there are two goals: (1) aligning texts and images in the sparse embedding space, and (2) grounding the sparse vector in human-understandable words from the vocabulary. To achieve both, the authors propose a 3-stage training procedure:

Stage 1: Training the image embedding with masked tokens. The image and text encoders are trained jointly, with a binary mask applied to the text embedding. By matching against the masked text embedding, the image encoder learns to ground its image embedding in the tokens of the paired text. After stage 1, the image embedding therefore lives in the interpretable space of the vocabulary.

Stage 2: Training with a frozen image encoder. This stage grounds the text embedding in the same interpretable space that the image embedding was trained to occupy in stage 1. The key idea is to let the image encoder act as a teacher model for the text encoder. After stage 2, both image and text embeddings lie in the same human-interpretable embedding space defined by the vocabulary.

Stage 3: Fine-tuning both encoders. Both encoders are fine-tuned jointly to boost image-text matching performance.
https://i.imgur.com/PWrEbkk.png

To further encourage sparsity, they add a FLOPs regularization loss so that only a small number of the $V$ token entries in each embedding are non-zero.
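The embedding steps and the sparsity regularizer described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions, not the authors' implementation: the matrix `W` is a hypothetical linear projection standing in for the BERT MLM mapping, the ReLU before pooling is an assumption made here so that inactive tokens are exactly zero, and `flops_loss` uses the commonly cited form of the FLOPs regularizer (sum over vocabulary entries of the squared mean absolute activation across a batch), which the summary does not spell out.

```python
import numpy as np

def sparse_embedding(h, W):
    """Map token features to one vocabulary-sized vector.

    h: (seq_len, d) token features from an encoder.
    W: (d, V) hypothetical projection standing in for the MLM mapping.
    """
    logits = h @ W                       # p(h_j) for every position j: (seq_len, V)
    weights = np.maximum(logits, 0.0)    # assumption: ReLU so unused tokens are exactly 0
    return weights.max(axis=0)           # max pooling over positions: (V,)

def flops_loss(batch):
    """Assumed FLOPs regularizer: sum_k (mean_i |batch[i, k]|)^2.

    batch: (N, V) embeddings for N inputs.
    """
    mean_activation = np.mean(np.abs(batch), axis=0)  # (V,)
    return float(np.sum(mean_activation ** 2))

# Toy usage with random features (sizes are arbitrary).
rng = np.random.default_rng(0)
seq_len, d, V = 5, 8, 20
W = rng.normal(size=(d, V))
batch = np.stack(
    [sparse_embedding(rng.normal(size=(seq_len, d)), W) for _ in range(4)]
)
print(batch.shape, flops_loss(batch))
```

Because the penalty squares the batch-mean activation of each vocabulary entry, it pushes hardest on entries that are active for many inputs, which is what drives the embeddings toward having few non-zero tokens.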