A Unified Masked Jigsaw Puzzle Framework for Vision and Language Models
Weixin Ye 1, Wei Wang 1, Yahui Liu 2, Yue Song 3, Bin Ren 4, Wei Bi 5, Rita Cucchiara 6, Nicu Sebe 4
Published on arXiv
2601.12051
Model Inversion Attack
OWASP ML Top 10 — ML03
Key Finding
MJP makes it harder to reconstruct input data from the gradients of Position Embeddings, while also improving image classification accuracy on ImageNet-1K and sentiment analysis accuracy on Yelp and Amazon.
Masked Jigsaw Puzzle (MJP)
Novel technique introduced
In federated learning, the Transformer, a popular architecture, faces critical challenges in defending against gradient attacks and improving model performance in both Computer Vision (CV) and Natural Language Processing (NLP) tasks. It has been shown that the gradients of Position Embeddings (PEs) in a Transformer contain enough information to reconstruct the input data. To mitigate this issue, we introduce a Masked Jigsaw Puzzle (MJP) framework. MJP starts with random token shuffling to break the token order, and then uses a learnable unknown (unk) position embedding to mask out the PEs of the shuffled tokens. In this manner, the local spatial information encoded in the position embeddings is disrupted, and the models are forced to learn feature representations that rely less on local spatial information. Notably, with careful use of MJP, we can not only improve models' robustness against gradient attacks, but also boost their performance in both vision and text application scenarios, such as image classification (e.g., ImageNet-1K) and sentiment analysis for text (e.g., Yelp and Amazon). Experimental results suggest that MJP is a unified framework for different Transformer-based models in both vision and language tasks. Code is publicly available via https://github.com/ywxsuperstar/transformerattack
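The shuffle-and-mask step described above can be sketched in a few lines of plain Python. This is an illustrative sketch, not the authors' implementation: the function name `masked_jigsaw`, the `mask_ratio` parameter, and the use of additive position embeddings are assumptions made here, and a real model would make `unk_emb` a trainable parameter rather than a fixed list.

```python
import random


def masked_jigsaw(tokens, pos_emb, unk_emb, mask_ratio=0.25, rng=None):
    """Apply a masked-jigsaw step to one sequence (hypothetical helper).

    tokens, pos_emb: lists of equal-length float vectors, one per position.
    unk_emb: the single shared 'unknown' position embedding.
    mask_ratio: assumed fraction of positions to shuffle and mask.
    Returns (inputs, selected), where inputs[i] = token + position embedding.
    """
    rng = rng or random.Random()
    n = len(tokens)
    k = int(n * mask_ratio)

    # Pick the positions to disturb, then permute tokens among them.
    selected = sorted(rng.sample(range(n), k))
    permuted = selected[:]
    rng.shuffle(permuted)

    # source[i] is the index of the token that ends up at position i.
    source = list(range(n))
    for dst, src in zip(selected, permuted):
        source[dst] = src

    selected_set = set(selected)
    inputs = []
    for i in range(n):
        # Shuffled positions get the shared unk embedding instead of their own PE.
        pe = unk_emb if i in selected_set else pos_emb[i]
        tok = tokens[source[i]]
        inputs.append([t + p for t, p in zip(tok, pe)])
    return inputs, selected


# Toy example: 8 positions, 2-dimensional embeddings.
tokens = [[float(i)] * 2 for i in range(8)]
pos_emb = [[0.1 * i] * 2 for i in range(8)]
unk_emb = [-1.0, -1.0]
inputs, selected = masked_jigsaw(
    tokens, pos_emb, unk_emb, mask_ratio=0.5, rng=random.Random(0)
)
```

Positions outside `selected` keep their own token and position embedding, so only the chosen subset loses its positional signal. How that subset is sized and scheduled during training is not specified in this summary.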
Key Contributions
- Masked Jigsaw Puzzle (MJP) framework that randomly shuffles token order and replaces the position embeddings of the shuffled tokens with a learnable 'unknown' embedding, disrupting gradient-based data reconstruction
- Unified defense applicable to both vision (ViT on ImageNet-1K) and language (BERT on Yelp/Amazon) Transformer models in federated learning settings
- Demonstration that MJP simultaneously improves robustness against gradient attacks and boosts downstream task performance
🛡️ Threat Analysis
The paper's primary threat model is gradient leakage/inversion in federated learning: an adversary reconstructs private input data from shared gradients of Transformer Position Embeddings. MJP is proposed as a defense that disrupts the spatial information encoded in those gradients, directly mitigating gradient-based data reconstruction attacks.
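One way to see why a shared unknown embedding limits this leakage is to compare the position-embedding gradients with and without masking. The derivation below is a sketch under the assumption that position embeddings are added to token embeddings; the symbols $t_i$ (token embedding), $p_i$ (position embedding), $x_i$ (model input at position $i$), $\mathcal{L}$ (loss), $M$ (set of masked positions), and $\pi$ (the shuffle) are notation introduced here, not taken from the paper.

```latex
% Without MJP: each position has its own embedding, so its gradient
% exposes the input gradient at exactly that position.
x_i = t_i + p_i
\quad\Longrightarrow\quad
\frac{\partial \mathcal{L}}{\partial p_i} = \frac{\partial \mathcal{L}}{\partial x_i}

% With MJP: every masked position shares p_unk, so only the sum
% over the masked set M is visible in the shared gradient.
x_i = t_{\pi(i)} + p_{\mathrm{unk}} \quad (i \in M)
\quad\Longrightarrow\quad
\frac{\partial \mathcal{L}}{\partial p_{\mathrm{unk}}}
  = \sum_{i \in M} \frac{\partial \mathcal{L}}{\partial x_i}
```

Under this assumption, the masked positions contribute a single summed gradient rather than one gradient per position, and the shuffle $\pi$ further decouples each token from its original location. This illustrates the intuition only and is not a formal privacy guarantee.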