SRNTT: Image Super-Resolution by Neural Texture Transfer - Introduction and Implementation
こんにちは,さしすせ (Twitter, GitHub) と申します.初めて個人のブログを稼働させます. 現在修士課程に在学しており,大学・インターン先各所にて深層学習を用いた画像認識・生成の研究開発を行なっております. 本エントリでは,CVPR2019にて発表されたRefSR手法であるSRNTT (Image Super-Resolution by Neural Texture Transfer) [*1][*2]について解説・実装を行います.図は論文から引用しています.実装はPyTorchで行い,下記に置いてあります.
本エントリは,インターン先であります Navier株式会社 (ナビエ) にて取り組んでいる内容の一部となります. 弊社では,深層学習を用いた画像変換技術の開発およびソフトウェアの提供を行なっており,特にLow-Level Vision周りの技術に力を入れております. 本エントリをお読みになった方で興味をお持ちの方がいらっしゃいましたら,気軽にお問い合わせいただけると幸いです.
TL; DR
- RefSRは,入力画像に加えて参照画像を用いる超解像手法.
- SRNTTは,RefSRの制約を緩めて広く利用可能にした手法.
- PyTorchによる再現実装を行い,コードを公開した
目次
超解像
超解像 (Super-Resolution, SR) とは,入力信号の解像度を高めて出力する技術の総称であり,端的に言うと画像を大きくする技術です. 超解像は下図のようなImage Restoration 問題の一種と捉えられます.
Image Restoration 問題は,観測される低解像度画像 が高解像度画像 が劣化 を通って生成されると仮定し, の逆変換 を求めることで画像を再構成する問題です. Image Restoration 問題は典型的な逆問題であり,仮に が既知であっても から を一意に定めることは困難です. このような問題を ill-posed問題 といい,超解像においては拡大倍率が大きくなるほどill-posed性はきつくなっていきます.
近年,超解像の解法として畳み込みニューラルネットワーク(CNN)が多く用いられます. CNNベースの超解像のうち,単一のLR画像のみを入力とするSingle Image Super-Resolution (SISR)は,特に競争が激化している分野と言えます. 数多くのネットワーク構造や知覚品質を向上させるための学習方法などがここ数年で提案されてきましたが,上述のill-posed性により高倍率の復元にはまだまだ 難があり,単純なSISRでの精度向上には限界が見えつつあります.
RefSR
RefSR (Reference-based Super-Resolution) は,単一の画像を用いるSISRとは対照的に,参照画像を用いた超解像手法です. 画像検索などを使って取得できる高解像度な類似画像を参照画像とし,LR画像と似た部分を抽出して貼り合せることで画像を復元するようなイメージです. この考え方は結構本質的で,VISTA-Vision[*3]やA+[*4]などでNon-deep時代から検討されていますが,当時は画像というよりはパッチ単位で考えていたため,近年言われるようなセマンティックな部分をあまり考慮できていなかった側面があります*5. 本エントリでは,画像全体を参照するものをRefSRと呼ぶこととし,以下で関連研究を2つ例示します.
Landmark Image Super-Resolution by Retrieving Web Images [*6]
本手法(以下,Landmark)は局所特徴量をベースとした画像検索とレジストレーションによって構成された超解像手法です.Landmarkの処理の流れを次に示します.
まず,低解像度画像 をBicubic補間によって任意のサイズに拡大して補間画像 を生成します. 次に, に対してSIFT[*7]特徴量 を計算し,の Bag of Visual Words (BoVW) 表現を得ます. ここで求めたBoVW表現を用いて画像検索を行い,類似画像を複数枚ピックします. 類似画像においてもSIFT特徴量は計算済みであるため,この段階でキーポイントマッチングとRANSACによって射影変換行列を計算し,画像のレジストレーションを行います. 最後にパッチマッチングで微小なズレを補正し,エネルギー最小化によってパッチをブレンドすることで,出力画像 を生成します.
CrossNet: An End-to-end Reference-based Super Resolution Network using Cross-scale Warping [*8]
Landmarkなどの従来手法では,参照画像を取得した後にレジストレーションやパッチマッチングを独立に行うため,最適化が難しいなどの課題がありました. CrossNetではこれらの処理をCNN内で行い,end-to-endでの学習を可能にしました.
CrossNetへの入力画像は学習済みモデルによって事前に拡大しておきます*9. メインのネットワークはEncoder-Decoder構造をしており,Decoderで参照画像の特徴を利用して画像を復元します. ここで用いる参照画像の特徴は,Flow estimatorが求めたOptical flowによって変形され,Decoderに渡されています. これにより位置合わせがCNN内で完結し,end-to-end学習による安定した最適化と高精度化を達成しました*10.
SRNTT
前置きが長くなりましたが,ここから本題となります.
上述のように,LandmarkやCrossNetは参照画像を幾何的変換によって変形して利用します. 言い換えると,利用可能な参照画像は幾何的変換によって位置合わせできる画像に限定されます. 建造物や風景画像を除けば,このような条件での画像収集は難しいため,RefSRの利用範囲は狭いものとなっていました.
SRNTTでは,幾何的な構造を意識せずに参照画像を利用し,RefSRの制約を 「似たものが映っていれば利用可能」 くらいに緩めることを可能にしました. 以下では,SRNTTの具体的なネットワーク構造や参照画像の利用方法について,実装を交えながら解説します.
ネットワーク構造
SRNTTの構造を次に示します.
何やら複雑ですが,SRNTTは主に2つのブランチで構成されています. 一方は図中下のFeature swappingブランチで, と の特徴のマッチング・交換を行い,Swapped feature を計算します. 他方は図中上のTexture transferブランチで, を利用して のテクスチャを転写し,画像を再構成していきます.
Feature swapping
Feature swappingブランチでの処理は,簡単にいうと から と似た部分を探してくる処理になります. と の間には解像度のギャップがあるため,直接これらを比較することはできません. また,Texture transferブランチにマルチスケールな特徴を渡すため,拡大後の解像度空間で処理を行う必要があります. したがって,マッチングは を拡大した と を直列に縮小・拡大した の間で行い, マッチング結果をもとに の特徴を参照し,Swapped feature を計算することになります*11.
Patch matching
ここでは, と の類似度を特徴空間上のパッチ単位で計算し, の各パッチに最も類似した のパッチを求めます.
まず, と を特徴抽出器 に通し,特徴マップ , を得ます. , をパッチに分解し, それぞれの , 番目のパッチを , とします. 任意の , ペアに対し,次のように内積を計算し,類似度 を計算します. $$ s_{i, j} = \left\langle P_i(\phi(I^{LR\uparrow})), \frac{P_j(\phi(I^{Ref \downarrow \uparrow}))}{||P_j(\phi(I^{Ref \downarrow \uparrow}))||} \right\rangle $$ この計算は に対するカーネル による畳み込みと表現できるため,実際には ごとに類似度マップ を計算します. $$ S_{j} = \phi(I^{LR\uparrow}) \ast \frac{P_j(\phi(I^{Ref \downarrow \uparrow}))}{||P_j(\phi(I^{Ref \downarrow \uparrow}))||} $$ 最後に,位置 における類似度 が最大となるインデックス を計算します. $$ j^\ast = \mathrm{arg} \max_{j} S_j (x, y) $$ これを全ての に対して計算し,マッチング結果とします. このように任意のパッチの組み合わせに対して類似度を計算するため,SRNTTにおいて従来手法のような幾何的な構造の制約はありません.
以上をコードに落とすと,次のようになります.
map_in
は , map_ref_blur
は を表しています.
にはVGG19を用い,relu3_1
の出力に対してPatch matchingを行なっています.
import torch.nn.functonal as F # relu3_1の特徴を抽出, shape: (C, H, W) content = map_in['relu3_1'] condition = map_ref_blur['relu3_1'] # パッチに分解, shape: (C, patch_size, patch_size, n_patches) patch_condition = sample_patches(condition) # 正規化 patch_condition /= patch_condition.norm(p=2, dim=(0, 1, 2)) + 1e-5 # 類似度計算, shape: (1, n_patches, H-(patch_size//2), W-(patch_size//2)) sim = F.conv2d(content.unsqueeze(0), patch_condition.permute(3, 0, 1, 2)) # 類似度最大となるインデックスを計算 max_val, max_idx = sim.squeeze(0).max(dim=0)
理想的には,Patch matchingをTexture transferに渡す全てのスケールに対して行うのですが,論文では高速化のためrelu3_1
でのみPatch matchingを計算し,その結果を用いて以降のTexture swappingに進むとしています.
Texture swapping
ここでは,Patch matchingの結果を元に の特徴を参照していきます. Swapped feature の位置 には $$ P_{\omega(x, y)}(M) = P_{j^\ast}(\phi(I^{Ref})) $$ が参照され,その値が足し込まれていきます.ここで は位置 とPatch matchingの結果を対応づける関数です. ただし, を密にサンプルすると隣のパッチとオーバーラップが発生するので,最後に足し込まれた回数によって平均をとります.
以上をコードに落とすと次のようになります.
Patch matchingのときと同様に,map_in
は , map_ref
は を表しています.
# 3つのスケールに対してTexture swappingを行う swapped_maps = {} for idx, layer in enumerate(['relu3_1', 'relu2_1', 'relu1_1']): # Patch matchingをrelu3_1でしかやっていないため, # relu2_1, relu1_1ではスケールに応じてパッチサイズとストライドを変更 (本来は不要) ratio = 2 ** idx _patch_size = self.patch_size * ratio _stride = self.stride * ratio # ターゲットとなるレイヤの特徴 content = getattr(map_in, layer).squeeze(0) style = getattr(map_ref, layer).squeeze(0) # パッチに分解 (Ref画像は特徴をそのまま使いたいので正規化しない) patches_style = self.sample_patches(style, _patch_size, _stride) # マップを初期化 target_map = torch.zeros_like(content).to(self.device) count_map = torch.zeros(target_map.shape[1:]).to(self.device) # 類似度最大のパッチを加算していく for i in range(max_idx.shape[0]): for j in range(max_idx.shape[1]): _i, _j = i * ratio, j * ratio target_map[:, _i:_i+_patch_size, _j:_j+_patch_size]\ += patches_style[..., max_idx[i, j]] count_map[_i:_i+_patch_size, _j:_j+_patch_size] += 1 target_map /= count_map # 加算回数で平均 # 辞書に追加 swapped_maps.update({layer: target_map.cpu().numpy()})
Feature swappingの処理は以上となります.
先で述べたように,SRNTTにおけるFeature swappingはパッチ単位のマッチングにより位置関係を無視できます.
この特性は複数の参照画像を簡単に利用できることを示唆していて,大規模なデータベースから画像を引いてこれればさらなる精度向上が期待できます.
ただし,特徴マップの任意の位置で類似度を計算する関係上計算コストがかなり高いため,低いレイテンシを求められる状況には向かないという面もあります.
実際に,CUFED5データセット(主な画像サイズ: 500 x 330
)に対してFeature swappingを行うと,推論そのものが 1[s]
ほどなのに対し,Feature swappingは 30[s]
ほどかかります.
そのため,学習データに対してはオフラインでSwapped featureを計算しておく仕様になっています.
Texture transfer
さて,いよいよ画像を拡大するフェーズに来ました. Texture transferブランチは主に2つのコンポーネントで構成されています. 1つ目はContent extactorで一般的な超解像の特徴抽出部分と同じものになります. 2つ目はConditional texture transferで,Swapped featureを使って画像を再構成します.
Content extractor
論文ではContent extractorとしてSRGAN[*12]と同じ構造のネットワークが用いられていますが,
今回はMMSRで提供されている学習済みモデルを利用したかったため,Batch normalizationなどを廃したMSRGANの構造に合わせています.
なおforward()
の出力のうち,h
がContent extractorの出力であり,Conditional texture transferに渡されます*13.
import torch.nn as nn class ContentExtractor(nn.Module): def __init__(self, ngf=64, n_blocks=16): super(ContentExtractor, self).__init__() self.head = nn.Sequential( nn.Conv2d(3, ngf, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.1, True) ) self.body = nn.Sequential( *[ResBlock(ngf) for _ in range(n_blocks)], ) self.tail = nn.Sequential( nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1), nn.PixelShuffle(2), nn.LeakyReLU(0.1, True), nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1), nn.PixelShuffle(2), nn.LeakyReLU(0.1, True), nn.Conv2d(ngf, ngf, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.1, True), nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1), ) def forward(self, x): h = self.head(x) h = self.body(h) + h upscale = self.tail(h) return upscale, h class ResBlock(nn.Module): def __init__(self, n_filters=64): super(ResBlock, self).__init__() self.body = nn.Sequential( nn.Conv2d(n_filters, n_filters, 3, 1, 1), nn.ReLU(True), nn.Conv2d(n_filters, n_filters, 3, 1, 1), ) def forward(self, x): return self.body(x) + x
Conditional texture transfer
ここでは,Content extractorの出力とSwapped featureを用いて画像を拡大していきます. 入力特徴を ,そのスケールに対応するSwapped featureを とすると,次のスケールに渡される特徴 は $$ \psi_{l+1} = [ \mathrm{Res}(\psi_{l} || M_{l}) + \psi_{l}]_{\uparrow 2 \times} $$ と表されます. はResidual blocks, はチャネル方向のconcatを表しています. ネットワーク構造的には次のようなります.
倍率が4倍である場合, に対してこの計算を行っていきます.なお, においては拡大演算は行わず出力チャネルに合わせる畳み込みが適用されます.
これをコードにすると次のようになります.
class TextureTransfer(nn.Module): def __init__(self, ngf=64, n_blocks=16): super(TextureTransfer, self).__init__() # for small scale self.head_small = nn.Sequential( nn.Conv2d(ngf + 256, ngf, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.1, True), ) self.body_small = nn.Sequential( *[ResBlock(ngf) for _ in range(n_blocks)], ) self.tail_small = nn.Sequential( nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1), nn.PixelShuffle(2), nn.LeakyReLU(0.1, True), ) # for medium scale self.head_medium = nn.Sequential( nn.Conv2d(ngf + 128, ngf, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.1, True), ) self.body_medium = nn.Sequential( *[ResBlock(ngf) for _ in range(n_blocks)], ) self.tail_medium = nn.Sequential( nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1), nn.PixelShuffle(2), nn.LeakyReLU(0.1, True), ) # for large scale self.head_large = nn.Sequential( nn.Conv2d(ngf + 64, ngf, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.1, True), ) self.body_large = nn.Sequential( *[ResBlock(ngf) for _ in range(n_blocks)], ) self.tail_large = nn.Sequential( nn.Conv2d(ngf, ngf // 2, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.1, True), nn.Conv2d(ngf // 2, 3, kernel_size=3, stride=1, padding=1), ) def forward(self, x, maps): # small scale h = torch.cat([x, maps['relu3_1']], 1) h = self.head_small(h) h = self.body_small(h) + x x = self.tail_small(h) # medium scale h = torch.cat([x, maps['relu2_1']], 1) h = self.head_medium(h) h = self.body_medium(h) + x x = self.tail_medium(h) # large scale h = torch.cat([x, maps['relu1_1']], 1) h = self.head_large(h) h = self.body_large(h) + x x = self.tail_large(h) return x
目的関数
Reconstruction loss
ここではピクセルごとのL1距離をReconstruction loss とします. 超解像ではMSEよりL1距離を用いるほうがシャープな結果が得られるとされていてよく用いられます. $$ \mathcal{L}_{rec} = || I^{HR} - I^{SR}||_1 $$
Perceptual loss
ここではJohnson et al. [*14] にならい,VGG特徴のノルムを計算します. 論文ではノルムにフロベニウスノルムが使用されています. $$ \mathcal{L}_{per} = || \phi_i(I^{HR}) - \phi_i(I^{SR})||_F $$
Adversarial loss
WGAN-GP [*15]と同じ損失関数を定義します. 識別機の学習時にはここにGradient penaltyが加算されて学習の安定化を図ります. $$ \quad\quad\quad\quad\quad \mathcal{L}_{adv} = -\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})], \\ \min_{G} \max_{D \in G} \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] -\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})], $$
Texture loss
ここではENet [*16] にならい,VGG特徴とSwapped feature のGram行列をマルチスケールで比較します. はスケールを表しています. $$ \mathcal{L}_{t ex} = \sum_l \lambda_{l} || Gr( \phi_{l} (I^{SR}) \cdot S^\ast_l )- Gr(M_l \cdot S^\ast_l)||_F $$ の計算時に最大値を取っている関係上,似たものが写っていない状況では類似度の低い特徴を引いてきている可能性があります. 類似度の高い特徴を優先的に利用するため, における類似度マップ を用いて重み付けを行なった上で比較を行います.
Overall objective
上記の目的関数を線形結合して目的関数 を定義します. ,,は各関数の重みを表していて,論文ではそれぞれ ,,に設定されます. $$ \mathcal{L} = \mathcal{L}_{rec} + \lambda_{per} \mathcal{L}_{per} + \lambda_{adv} \mathcal{L}_{adv} + \lambda_{te x} \mathcal{L}_{te x} $$
実験
再掲になりますが,著者によるTensorFlow実装を参考に,SRNTTをPyTorchによって再現実装を行いました.
以下の実験結果において,SRNTTとして提示する画像は全て再現実装の出力となります.
条件
基本的に論文と同じ条件で学習を行います.
- データセット
- 学習データ: CUFED
- train:val=9:1 にランダムスプリット.
- テストデータ: CUFED5
- 入力画像と類似度順に5枚の画像がセットになっているデータセット.
- 学習データ: CUFED
- 倍率:
- 学習回数
- 事前学習: のみを使って5エポック
- この際,Content extractorの初期値にはMMSRが提供するMSRGANの重みを使用.
- メインの学習: で100エポック
- 事前学習: のみを使って5エポック
- バッチサイズ: 9
- 学習率
- 初期値:
- スケジューラ: 50エポックで
- オプティマイザ: Adam
結果
まずテストセットにおけるPSNRについて,論文が示す結果は次のようになります. HRは正解画像をランダムに変形したものを参照画像とした結果であり,その他は類似画像に対する結果で類似度順にL1からL5まで並んでいます.
これに対し,手元で学習したモデルのPSNRは次のようになりました. おおよそ再現できたのかなという感じです.
Ref画像 | HR | L1 | L2 | L3 | L4 | L5 |
---|---|---|---|---|---|---|
PSNR[dB] | 34.54 | 26.60 | 26.08 | 25.94 | 25.81 | 25.75 |
続いて,出力画像を見ていきます. SRNTT(HR)はHR画像自身を参照画像とした結果,SRNTT(L3)はRef Image(L3)を参照画像とした結果となります. この結果では,入力画像と参照画像に共通してNHL*17のロゴが写っており,SRNTT(L3)ではその部分を転写できていることが確認できます.
さらによく見ると,HR画像のロゴが若干ブレているため,ブレのないL3を参照したほうが結果がシャープになっていることがわかります. このように,質の高い参照画像が手に入る環境では,原画像を超える復元が可能となる場合があります.
次の結果では某夢の国っぽいお城の構造が転写できることが確認できます.
最後に少し面白い結果ですが,下段は上段を参照画像とした際の出力です. 祖母が写っているL3では口元を正しく転写していると思われますが,写っていないL2では別人の特徴を転写してきています. このことから,ある程度コンテンツの意味を認識して転写が行われるということがわかります*18.
まとめ
本エントリでは,最新のRefSR手法であるSRNTTについて解説し,PyTorchによって再現実装を行いました. 実験結果から,参照画像から高解像度の特徴を転写してかなり綺麗な画像復元が可能となることを確認しました.
ただし結果を見るとわかるように,転写ができていない部分が辛いという問題があり,複数画像を参照するなど,研究を進める過程で改善していきたいと考えています. 本エントリも修正があれば随時更新していく所存です.
*1:Zhang, Zhifei, et al. "Image Super-Resolution by Neural Texture Transfer." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.
*2:Zhang, Zhifei, et al. "Reference-conditioned super-resolution by neural texture transfer." arXiv preprint arXiv:1804.03360 (2018).
*3:Freeman, William T., Egon C. Pasztor, and Owen T. Carmichael. "Learning low-level vision." International journal of computer vision 40.1 (2000): 25-47.
*4:Timofte, Radu, Vincent De Smet, and Luc Van Gool. "A+: Adjusted anchored neighborhood regression for fast super-resolution." Asian conference on computer vision. Springer, Cham, 2014.
*5:このような手法群をExample-based Super-Resolutionと言ったりします.
*6:Yue, Huanjing, et al. "Landmark image super-resolution by retrieving web images." IEEE Transactions on Image Processing 22.12 (2013): 4865-4878.
*7:Lowe, David G. "Distinctive image features from scale-invariant keypoints." International journal of computer vision 60.2 (2004): 91-110.
*8:Zheng, Haitian, et al. "Crossnet: An end-to-end reference-based super resolution network using cross-scale warping." Proceedings of the European Conference on Computer Vision (ECCV). 2018.
*9:なので,実質的には超解像の結果を参照画像を使って改善する手法とも言えます.
*10:SISRに比べて8-10[dB]ほど改善します
*11:拡大・縮小は基本的にBicubic補間で行いますが,どんな手法を使っても構いません(精度に対して影響はないらしい).
*12:Ledig, Christian, et al. "Photo-realistic single image super-resolution using a generative adversarial network." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.
*13: upscaleは学習済みモデルの読み込みとデバッグのために用意してあります.
*14:Johnson, Justin, Alexandre Alahi, and Li Fei-Fei. "Perceptual losses for real-time style transfer and super-resolution." European conference on computer vision. Springer, Cham, 2016.
*15:Gulrajani, Ishaan, et al. "Improved training of wasserstein gans." Advances in neural information processing systems. 2017.
*16:Sajjadi, Mehdi SM, Bernhard Scholkopf, and Michael Hirsch. "Enhancenet: Single image super-resolution through automated texture synthesis." Proceedings of the IEEE International Conference on Computer Vision. 2017.
*17:余談ですが,4大スポーツで僕が唯一見ないスポーツです.
*18:別人貼り付けが良いか悪いかは置いておいて