さしすせブログ

技術的なことを書きたい

微分可能画像処理ライブラリ Kornia への招待

f:id:S_aiueo321:20191201152115j:plain

はじめに

本エントリは Sansan Advent Calendar 2019 第9日目の記事です. R&D系の話題に限らず,多様な話題が投稿されますのでぜひウォッチしていただけると幸いです 🙏

adventar.org

TL; DR

Korniaはイイぞ

PyTorch + Computer Vision

PyTorchといえば,言わずと知れた機械学習フレームワークですよね. 検索数はここ1-2年で一気にChainerを抜き去り*1,TensorFlowにも肉薄しています. TensorFlowはプロダクションにおける人気が絶大ということもあって,研究サイドにフォーカスするとPyTorch人気はより一層強調されます*2

Computer Vision界隈においてPyTorchが人気な理由の一つとして,torchvision の存在があります. torchvisionは画像データローダにおける前処理や学習済みモデルを提供するライブラリです. 特に前処理が簡単に書けるのが良いところで,込み入ったことをしない限り,次のようにスッキリ書けます.

from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class OleOleDataset(Dataset):
    def __init__(self, data_dir):
        self.filenames = list(Path(data_dir).glob('*.png'))
        self.transforms = transforms.Compose([  # 前処置を積んでく
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def __getitem__(self, index):
        filename = self.filenames[index]
        img = Image.open(filename).convert('RGB')
        return self.transforms(img)

    def __len__(self):
        return len(self.filenames)


dataset = OleOleDataset(data_dir='./hoge')
dataloader = DataLoader(dataset)

__getitem__()transforms()PIL.Image を渡していることからわかるように,torchvisionの前処理は基本的にPIL(Pillow)のラッパになっています. すごく便利ですが,人間は怠惰なので次のような要望が出てきます.

  1. PILを使いたくない (≒ NumPy, OpenCVを使いたい)
    • 👼「ndarray に変換すれば使える」
    • 👨「astype('uint8') とか RGB2BGR とかを一々書きたくないです」
  2. データローダ外でも使いたい
    • 👨「ネットワークで RGB2GRAY とか幾何的変換してからLossを計算したいのですが」
    • 👼「それはtorchvisionの守備範囲外では」
  3. せっかくだからGPU使いたい
    • 👼「cv2.cuda というのがあってだな」
    • 👨「辛そう…」

こんな問答を解決する一手として,今回の主題であります Kornia を以下で紹介します.

Kornia

KorniaはPyTorchをバックエンドとした 微分可能画像処理ライブラリ です. 元々はPyTorch Geometryという名前で幾何学的処理を提供するライブラリでしたが,グラフニューラルネットワークを扱うPyTorch Geometricと名前が競合することや,よりジェネリックな画像処理ライブラリへの拡張を考え,今の名前に改名されたらしいです. OpenCVにインスパイアされている部分が多く,実際に OpenCV.orgOpen Source Vision Foundation とパートナーシップを結んでいたりしています*3. 近々ではPyTorch ecosystemの仲間入りを果たしました.

以下では,WACV2020に採択された論文[*4]とライブラリの概観を説明します.なお,ライブラリの構成等は2019/12/08時点のものであり,stableバージョンは 0.1.4.post2 を指します.

kornia.github.io

arxiv.org

Korniaの特徴

論文ではKorniaには次のような特徴があるとしています.

  1. 微分可能
  2. CPU/GPUで簡単に実行可能
  3. ミニバッチでの処理が得意
  4. マルチプロセス/マルチノードで実行可能
  5. JITコンパイラが優秀

端的に言うと 「PyTorchがバックエンドだから」 という感じです. 殆どの演算がPyTrochの標準モジュール nn.Module を継承しており,一般的なCNNのレイヤと同等に扱うことができるのが一番の強みと思います.

ライブラリ概観

ライブラリは次のようなサブモジュールで構成されています.

サブモジュール 内容
kornia.color 色空間の変換
kornia.filters 画像フィルタリング
kornia.losses 誤差関数及び評価指標
kornia.feature 特徴点検出周りの処理
kornia.geometry 幾何学的変換周りの処理
kornia.utils ユーティリティ
kornia.contrib 実験的な実装

以下では,各サブモジュールの使い方をザッと見ていきます. テスト画像には lenna を用い,入力サイズはPyTorchの形式に沿って  N \times C \times H \times W とします.

kornia.color

kornia.color では,主に色空間変換 *5 と輝度やコントラストの調整機能を提供しています. まだstableではないですが,アルファブレンドなども提供予定っぽいです.

import kornia.color as color

# as torch.Tensor (shape: N x 3 x H x W)
org = imread('lenna.png')

# color space conversions
gray = color.rgb_to_grayscale(org)
hsv = color.rgb_to_hsv(org)
hls = color.rgb_to_hls(org)

# color adjustments
brightness = color.adjust_brightness(org, brightness_factor=0.5)
contrast = color.adjust_contrast(org, contrast_factor=0.5)
gamma = color.adjust_gamma(org, gamma=0.5)

結果はこんな感じです.

f:id:S_aiueo321:20191207002548p:plainf:id:S_aiueo321:20191207002604p:plainf:id:S_aiueo321:20191207002606p:plain
色空間の変換結果
f:id:S_aiueo321:20191207004142p:plainf:id:S_aiueo321:20191207004145p:plainf:id:S_aiueo321:20191207004147p:plain
輝度・コントラスト・ガンマの調整

kornia.filters

kornia.filters では現在ぼかしとエッジ検出用のフィルタリング処理を提供しています. エッジ検出はグレースケールで計算されるため,nn.Sequential() にグレースケール変換とエッジ検出をまとめています. これがまさにnn.Module を継承している強みで,torchvision.transforms.Compose に処理を積む感覚で処理をまとめることができます.

import kornia.color as color
import kornia.filters as filters

# imread as torch.Tensor (shape: N x 3 x H x W)
org = imread('lenna.png')

# blurring
box_blur = filters.box_blur(org, kernel_size=(9, 9))
median_blur = filters.median_blur(org, kernel_size=(9, 9))
gaussian_blur = filters.gaussian_blur2d(org, kernel_size=(9, 9), sigma=(10, 10))

# edge detection
laplacian_detector = nn.Sequential(
    color.RgbToGrayscale(),
    filters.Laplacian(kernel_size=9)
)
laplacian_edge = laplacian_detector(org)

sobel_detector = nn.Sequential(
    color.RgbToGrayscale(),
    filters.Sobel()
)
sobel_edge = sobel_detector(org)

f:id:S_aiueo321:20191207005336p:plainf:id:S_aiueo321:20191207005340p:plainf:id:S_aiueo321:20191207005342p:plain
ぼかしフィルタの適用結果
f:id:S_aiueo321:20191207010424p:plainf:id:S_aiueo321:20191207010420p:plain
エッジ検出結果

kornia.losses

kornia.losses では公式のnnモジュールに含まれない誤差関数を提供しています. タスクごとにまとめると次のようになります(※印はFuture Releaseです).

タスク モジュール名
Object detection losses.FocalLoss
Semantic segmentation

losses.DiceLoss
losses.TverskyLoss

Depth prediction losses.InverseDepthSmoothnessLoss
画質評価

losses.SSIM
losses.TotalVariation
losses.PSNRLoss

私個人は画像変換を取り扱うことが多いので,画質評価指標が手軽に使えるのが嬉しいです*6kornia.losses はそのまま誤差関数として使うようできているのでSSIMやPSNRみたいな類似度も相違度として定義されてます.すなわちこれらの類似度は大小を反転して利用する必要があります.

kornia.feature

kornia.feature では特徴点検出・特徴記述周りの機能を提供します.

特徴点検出ではスケールピラミッドの計算やオリエンテーションの計算などいくつか処理を踏むのですが,Korniaでは feature.ScaleSpaceDetector というラッパが用意されており,各段の処理をモジュールとして渡すことで様々な組み合わせの特徴点検出が可能となります. 例えばHarrisのコーナー検出は次のように書けます.

import kornia.feature as feature

# as torch.Tensor (shape: N x 1 x H x W)
org = imread('lenna.png')

detector = feature.ScaleSpaceDetector(
    num_features=50,
    resp_module=feature.CornerHarris(0.04)
)
lafs, resps = detector(org)

num_features は検出する特徴点の数を表します. 本来の特徴点検出では点の数に制限はありませんが,ミニバッチ化の際に固定サイズでなければならないため,常に num_features 個の特徴点が出力されます*7. 検出した点の位置とスケールを可視化すると次のようになります.

f:id:S_aiueo321:20191208174100p:plain
Harrisのコーナー検出の結果

加えてオリエンテーションの計算やアフィン領域推定をしたい場合, detector にそれぞれモジュールを追加します.

detector = kornia.feature.ScaleSpaceDetector(
    num_features=50,
    resp_module=kornia.feature.CornerHarris(0.04),
    aff_module=kornia.feature.LAFAffineShapeEstimator(patch_size=19),
    ori_module=kornia.feature.LAFOrienter(patch_size=19)
)

f:id:S_aiueo321:20191208185321p:plain
Harrisのコーナー検出+αの結果

検出した特徴点に対してSIFT特徴を記述するには,次のように書きます. 検出したスケールからパッチをサンプリングしてきて128次元のベクトルに記述します.

descriptor = feature.SIFTDescriptor(32)

patches =  feature.extract_patches_from_pyramid(org, lafs)
B, N, CH, H, W = patches.size()

descs = descriptor(patches.view(B * N, CH, H, W)).view(B, N, -1)
print (descs.shape)  # Out: (N x 50 x 128)

kornia.geometry

kornia.geometryでは,幾何学的変換だったり3次元点への変換だったりを提供します. サブモジュールは次のような感じ.やはりPyTorch Geometryだっただけあって, ここが一番手厚いです.

  • kornia.geometry.camera
    • Pinhole Camera
    • Perspective Camera
  • kornia.geometry.conversions
  • kornia.geometry.depth
  • kornia.geometry.linalg
  • kornia.geometry.transform
  • kornia.geometry.warp

僕が3Dとか何もわからない ライブラリを追いきれないので,今回はホモグラフィ行列による画像変換をやってみます. OpenCVと同様に,変換前後の4点座標からホモグラフィ行列を計算し,画像変換をするには次のようになります.

import kornia.geometry as geometry

org = imread('lenna.png')
n, c, h, w = org .shape

# 変換前後の4点座標
points_src = torch.FloatTensor([[
    [0, 0], [512, 0], [512, 512], [0, 512]]])
points_dst = torch.FloatTensor([[
    [0, 0], [300, 150], [400, 300], [250, 500]]])

M = kornia.get_perspective_transform(points_src, points_dst)
img_warp = kornia.warp_perspective(org, M, dsize=(h, w))

f:id:S_aiueo321:20191208204249p:plain
モグラフィ行列による変換結果

論文では画像のレジストレーションで著名なLucas–Kanade法を,kornia.geometryとPyTorchの自動微分&勾配降下を利用して実行する例が示されています. 機会があれば同じ要領で画像勾配を用いた画像合成なども試せたらいいかなと思います.

まとめ

本エントリではPyTorchをバックエンドとした微分可能画像処理ライブラリ Kornia を紹介しました. 現状バージョン 0.1.4.post2 と低く発展途上ですが,部分的に使える機能も多い印象です*8. このようなエコシステムが成熟すれば研究の幅も広がるので期待大です.

まぁ結局何が言いたいかというと,

CV勢,コントリビュートしましょう!

*1:抜き去るどころか… - https://preferred.jp/ja/news/pr20191205/

*2:CVPR2019の某ワークショップでは約2/3がPyTorchでした.

*3:というか論文のラストオーサーは完全にOpenCVの中の人

*4:Riba, Edgar, et al. "Kornia: an Open Source Differentiable Computer Vision Library for PyTorch." 2020 IEEE Winter Conference on Applications of Computer Vision (WACV). IEEE, 2020.

*5:密かにRGB2YCbCrのPR投げてました.これ以上は聞かないで.

*6:今までは自分で書いたり,0.3系で書かれたレポジトリをコピーする事がよくありました.

*7:特徴点が検出されないような均質な画像に対しても同様です.

*8:実際に今進めているプロジェクトに使っていたり使っていなかったり…

SRNTT: Image Super-Resolution by Neural Texture Transfer - Introduction and Implementation

f:id:S_aiueo321:20191106154046j:plain
Photo by hosein ashrafosadat from Pexels

こんにちは,さしすせ (Twitter, GitHub) と申します.初めて個人のブログを稼働させます. 現在修士課程に在学しており,大学・インターン先各所にて深層学習を用いた画像認識・生成の研究開発を行なっております. 本エントリでは,CVPR2019にて発表されたRefSR手法であるSRNTT (Image Super-Resolution by Neural Texture Transfer) [*1][*2]について解説・実装を行います.図は論文から引用しています.実装はPyTorchで行い,下記に置いてあります.

github.com

本エントリは,インターン先であります Navier株式会社 (ナビエ) にて取り組んでいる内容の一部となります. 弊社では,深層学習を用いた画像変換技術の開発およびソフトウェアの提供を行なっており,特にLow-Level Vision周りの技術に力を入れております. 本エントリをお読みになった方で興味をお持ちの方がいらっしゃいましたら,気軽にお問い合わせいただけると幸いです.

www.navier.co

TL; DR

  • RefSRは,入力画像に加えて参照画像を用いる超解像手法.
  • SRNTTは,RefSRの制約を緩めて広く利用可能にした手法.
  • PyTorchによる再現実装を行い,コードを公開した

目次

超解像

超解像 (Super-Resolution, SR) とは,入力信号の解像度を高めて出力する技術の総称であり,端的に言うと画像を大きくする技術です. 超解像は下図のようなImage Restoration 問題の一種と捉えられます.

f:id:S_aiueo321:20191106145304p:plain
Image Restoration問題としての超解像

Image Restoration 問題は,観測される低解像度画像  I^{LR} が高解像度画像  I^{HR} が劣化  \mathcal{D} を通って生成されると仮定し, \mathcal{D} の逆変換  \mathcal{F} を求めることで画像を再構成する問題です. Image Restoration 問題は典型的な逆問題であり,仮に \mathcal{D} が既知であっても  I^{LR} から  \mathcal{F} を一意に定めることは困難です. このような問題を ill-posed問題 といい,超解像においては拡大倍率が大きくなるほどill-posed性はきつくなっていきます.

近年,超解像の解法として畳み込みニューラルネットワーク(CNN)が多く用いられます. CNNベースの超解像のうち,単一のLR画像のみを入力とするSingle Image Super-Resolution (SISR)は,特に競争が激化している分野と言えます. 数多くのネットワーク構造や知覚品質を向上させるための学習方法などがここ数年で提案されてきましたが,上述のill-posed性により高倍率の復元にはまだまだ 難があり,単純なSISRでの精度向上には限界が見えつつあります.

RefSR

RefSR (Reference-based Super-Resolution) は,単一の画像を用いるSISRとは対照的に,参照画像を用いた超解像手法です. 画像検索などを使って取得できる高解像度な類似画像を参照画像とし,LR画像と似た部分を抽出して貼り合せることで画像を復元するようなイメージです. この考え方は結構本質的で,VISTA-Vision[*3]やA+[*4]などでNon-deep時代から検討されていますが,当時は画像というよりはパッチ単位で考えていたため,近年言われるようなセマンティックな部分をあまり考慮できていなかった側面があります*5. 本エントリでは,画像全体を参照するものをRefSRと呼ぶこととし,以下で関連研究を2つ例示します.

Landmark Image Super-Resolution by Retrieving Web Images [*6]

本手法(以下,Landmark)は局所特徴量をベースとした画像検索とレジストレーションによって構成された超解像手法です.Landmarkの処理の流れを次に示します.

f:id:S_aiueo321:20191106185536j:plain
Landmarkの処理の流れ

まず,低解像度画像  I^{l} をBicubic補間によって任意のサイズに拡大して補間画像  \tilde{I}を生成します. 次に, \tilde{I} に対してSIFT[*7]特徴量  \tilde{\Omega} を計算し, \tilde{\Omega}の Bag of Visual Words (BoVW) 表現を得ます. ここで求めたBoVW表現を用いて画像検索を行い,類似画像を複数枚ピックします. 類似画像においてもSIFT特徴量は計算済みであるため,この段階でキーポイントマッチングとRANSACによって射影変換行列を計算し,画像のレジストレーションを行います. 最後にパッチマッチングで微小なズレを補正し,エネルギー最小化によってパッチをブレンドすることで,出力画像  I^{h} を生成します.

CrossNet: An End-to-end Reference-based Super Resolution Network using Cross-scale Warping [*8]

Landmarkなどの従来手法では,参照画像を取得した後にレジストレーションやパッチマッチングを独立に行うため,最適化が難しいなどの課題がありました. CrossNetではこれらの処理をCNN内で行い,end-to-endでの学習を可能にしました.

f:id:S_aiueo321:20191106201023p:plain
CrossNetの構造

CrossNetへの入力画像は学習済みモデルによって事前に拡大しておきます*9. メインのネットワークはEncoder-Decoder構造をしており,Decoderで参照画像の特徴を利用して画像を復元します. ここで用いる参照画像の特徴は,Flow estimatorが求めたOptical flowによって変形され,Decoderに渡されています. これにより位置合わせがCNN内で完結し,end-to-end学習による安定した最適化と高精度化を達成しました*10

SRNTT

前置きが長くなりましたが,ここから本題となります.

上述のように,LandmarkやCrossNetは参照画像を幾何的変換によって変形して利用します. 言い換えると,利用可能な参照画像は幾何的変換によって位置合わせできる画像に限定されます. 建造物や風景画像を除けば,このような条件での画像収集は難しいため,RefSRの利用範囲は狭いものとなっていました.

SRNTTでは,幾何的な構造を意識せずに参照画像を利用し,RefSRの制約を 「似たものが映っていれば利用可能」 くらいに緩めることを可能にしました. 以下では,SRNTTの具体的なネットワーク構造や参照画像の利用方法について,実装を交えながら解説します.

ネットワーク構造

SRNTTの構造を次に示します.

f:id:S_aiueo321:20191109152235j:plain
SRNTTの構造

何やら複雑ですが,SRNTTは主に2つのブランチで構成されています. 一方は図中下のFeature swappingブランチで, I^{LR} I^{Ref} の特徴のマッチング・交換を行い,Swapped feature  M を計算します. 他方は図中上のTexture transferブランチで,  M を利用して  I^{Ref} のテクスチャを転写し,画像を再構成していきます.

Feature swapping

Feature swappingブランチでの処理は,簡単にいうと  I^{Ref} から  I^{LR} と似た部分を探してくる処理になります.  I^{LR} I^{Ref}の間には解像度のギャップがあるため,直接これらを比較することはできません. また,Texture transferブランチにマルチスケールな特徴を渡すため,拡大後の解像度空間で処理を行う必要があります. したがって,マッチングは  I^{LR} を拡大した  I^{LR\uparrow} I^{Ref} を直列に縮小・拡大した  I^{Ref \downarrow \uparrow} の間で行い, マッチング結果をもとに  I^{Ref} の特徴を参照し,Swapped feature  M を計算することになります*11

Patch matching

ここでは, I^{LR\uparrow} I^{Ref \downarrow \uparrow} の類似度を特徴空間上のパッチ単位で計算し,  I^{LR\uparrow} の各パッチに最も類似した  I^{Ref \downarrow \uparrow} のパッチを求めます.

まず, I^{LR\uparrow} I^{Ref \downarrow \uparrow} を特徴抽出器  \phi に通し,特徴マップ  \phi(I^{LR\uparrow}) \phi(I^{Ref \downarrow \uparrow}) を得ます.  \phi(I^{LR\uparrow}) \phi(I^{Ref \downarrow \uparrow}) をパッチに分解し, それぞれの  i j 番目のパッチを  P_i (\phi(I^{LR\uparrow})) P_j(\phi(I^{Ref \downarrow \uparrow}))とします. 任意の  i j ペアに対し,次のように内積を計算し,類似度  s_{i, j} を計算します. $$ s_{i, j} = \left\langle P_i(\phi(I^{LR\uparrow})), \frac{P_j(\phi(I^{Ref \downarrow \uparrow}))}{||P_j(\phi(I^{Ref \downarrow \uparrow}))||} \right\rangle $$ この計算は  \phi(I^{LR\uparrow}) に対するカーネル  P_j(\phi(I^{Ref \downarrow \uparrow})) による畳み込みと表現できるため,実際には  P_j ごとに類似度マップ  S_j を計算します. $$ S_{j} = \phi(I^{LR\uparrow}) \ast \frac{P_j(\phi(I^{Ref \downarrow \uparrow}))}{||P_j(\phi(I^{Ref \downarrow \uparrow}))||} $$ 最後に,位置  (x, y) における類似度  S_j(x, y) が最大となるインデックス  j^\ast を計算します. $$ j^\ast = \mathrm{arg} \max_{j} S_j (x, y) $$ これを全ての  (x, y) に対して計算し,マッチング結果とします. このように任意のパッチの組み合わせに対して類似度を計算するため,SRNTTにおいて従来手法のような幾何的な構造の制約はありません.

以上をコードに落とすと,次のようになります. map_in \phi(I^{LR\uparrow})map_ref_blur \phi(I^{Ref \downarrow \uparrow}) を表しています.  \phi にはVGG19を用い,relu3_1の出力に対してPatch matchingを行なっています.

import torch.nn.functonal as F

# relu3_1の特徴を抽出, shape: (C, H, W)
content = map_in['relu3_1']
condition = map_ref_blur['relu3_1']

# パッチに分解, shape: (C, patch_size, patch_size, n_patches)
patch_condition = sample_patches(condition)
# 正規化
patch_condition /= patch_condition.norm(p=2, dim=(0, 1, 2)) + 1e-5

# 類似度計算, shape: (1, n_patches, H-(patch_size//2), W-(patch_size//2))
sim = F.conv2d(content.unsqueeze(0), patch_condition.permute(3, 0, 1, 2))

# 類似度最大となるインデックスを計算
max_val, max_idx = sim.squeeze(0).max(dim=0)

理想的には,Patch matchingをTexture transferに渡す全てのスケールに対して行うのですが,論文では高速化のためrelu3_1でのみPatch matchingを計算し,その結果を用いて以降のTexture swappingに進むとしています.

Texture swapping

ここでは,Patch matchingの結果を元に  \phi(I^{Ref}) の特徴を参照していきます. Swapped feature  M の位置  (x, y) には $$ P_{\omega(x, y)}(M) = P_{j^\ast}(\phi(I^{Ref})) $$ が参照され,その値が足し込まれていきます.ここで  \omega(\cdot, \cdot) は位置  (x, y) とPatch matchingの結果を対応づける関数です. ただし, (x, y) を密にサンプルすると隣のパッチとオーバーラップが発生するので,最後に足し込まれた回数によって平均をとります.

以上をコードに落とすと次のようになります. Patch matchingのときと同様に,map_in \phi(I^{LR\uparrow})map_ref \phi(I^{Ref}) を表しています.

# 3つのスケールに対してTexture swappingを行う
swapped_maps = {}
for idx, layer in enumerate(['relu3_1', 'relu2_1', 'relu1_1']):
    # Patch matchingをrelu3_1でしかやっていないため,
    # relu2_1, relu1_1ではスケールに応じてパッチサイズとストライドを変更 (本来は不要)
    ratio = 2 ** idx
    _patch_size = self.patch_size * ratio
    _stride = self.stride * ratio

    # ターゲットとなるレイヤの特徴
    content = getattr(map_in, layer).squeeze(0)
    style = getattr(map_ref, layer).squeeze(0)

    # パッチに分解 (Ref画像は特徴をそのまま使いたいので正規化しない)
    patches_style = self.sample_patches(style, _patch_size, _stride)

    # マップを初期化
    target_map = torch.zeros_like(content).to(self.device)
    count_map = torch.zeros(target_map.shape[1:]).to(self.device)

    # 類似度最大のパッチを加算していく
    for i in range(max_idx.shape[0]):
        for j in range(max_idx.shape[1]):
            _i, _j = i * ratio, j * ratio
            target_map[:, _i:_i+_patch_size, _j:_j+_patch_size]\
                += patches_style[..., max_idx[i, j]]
            count_map[_i:_i+_patch_size, _j:_j+_patch_size] += 1
    target_map /= count_map  # 加算回数で平均

    # 辞書に追加
    swapped_maps.update({layer: target_map.cpu().numpy()})

Feature swappingの処理は以上となります. 先で述べたように,SRNTTにおけるFeature swappingはパッチ単位のマッチングにより位置関係を無視できます. この特性は複数の参照画像を簡単に利用できることを示唆していて,大規模なデータベースから画像を引いてこれればさらなる精度向上が期待できます. ただし,特徴マップの任意の位置で類似度を計算する関係上計算コストがかなり高いため,低いレイテンシを求められる状況には向かないという面もあります. 実際に,CUFED5データセット(主な画像サイズ: 500 x 330)に対してFeature swappingを行うと,推論そのものが 1[s] ほどなのに対し,Feature swappingは 30[s] ほどかかります. そのため,学習データに対してはオフラインでSwapped featureを計算しておく仕様になっています.

Texture transfer

さて,いよいよ画像を拡大するフェーズに来ました. Texture transferブランチは主に2つのコンポーネントで構成されています. 1つ目はContent extactorで一般的な超解像の特徴抽出部分と同じものになります. 2つ目はConditional texture transferで,Swapped featureを使って画像を再構成します.

Content extractor

論文ではContent extractorとしてSRGAN[*12]と同じ構造のネットワークが用いられていますが, 今回はMMSRで提供されている学習済みモデルを利用したかったため,Batch normalizationなどを廃したMSRGANの構造に合わせています. なおforward()の出力のうち,hがContent extractorの出力であり,Conditional texture transferに渡されます*13

import torch.nn as nn

class ContentExtractor(nn.Module):
    def __init__(self, ngf=64, n_blocks=16):
        super(ContentExtractor, self).__init__()

        self.head = nn.Sequential(
            nn.Conv2d(3, ngf, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, True)
        )
        self.body = nn.Sequential(
            *[ResBlock(ngf) for _ in range(n_blocks)],
        )
        self.tail = nn.Sequential(
            nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(ngf, ngf, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        h = self.head(x)
        h = self.body(h) + h
        upscale = self.tail(h)
        return upscale, h

class ResBlock(nn.Module):
    def __init__(self, n_filters=64):
        super(ResBlock, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_filters, n_filters, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(n_filters, n_filters, 3, 1, 1),
        )

    def forward(self, x):
        return self.body(x) + x

Conditional texture transfer

ここでは,Content extractorの出力とSwapped featureを用いて画像を拡大していきます. 入力特徴を  \psi_l,そのスケールに対応するSwapped featureを  M_l とすると,次のスケールに渡される特徴  \psi_{l+1} は $$ \psi_{l+1} = [ \mathrm{Res}(\psi_{l} || M_{l}) + \psi_{l}]_{\uparrow 2 \times} $$ と表されます. \mathrm{Res}(\cdot) はResidual blocks,  || はチャネル方向のconcatを表しています. ネットワーク構造的には次のようなります.

f:id:S_aiueo321:20191113140012j:plain
Conditional texture transferの構造

倍率が4倍である場合, l=0,1,2 に対してこの計算を行っていきます.なお, l=2 においては拡大演算は行わず出力チャネルに合わせる畳み込みが適用されます.

これをコードにすると次のようになります.

class TextureTransfer(nn.Module):
    def __init__(self, ngf=64, n_blocks=16):
        super(TextureTransfer, self).__init__()

        # for small scale
        self.head_small = nn.Sequential(
            nn.Conv2d(ngf + 256, ngf, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, True),
        )
        self.body_small = nn.Sequential(
            *[ResBlock(ngf) for _ in range(n_blocks)],
        )
        self.tail_small = nn.Sequential(
            nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.1, True),
        )

        # for medium scale
        self.head_medium = nn.Sequential(
            nn.Conv2d(ngf + 128, ngf, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, True),
        )
        self.body_medium = nn.Sequential(
            *[ResBlock(ngf) for _ in range(n_blocks)],
        )
        self.tail_medium = nn.Sequential(
            nn.Conv2d(ngf, ngf * 4, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.1, True),
        )

        # for large scale
        self.head_large = nn.Sequential(
            nn.Conv2d(ngf + 64, ngf, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, True),
        )
        self.body_large = nn.Sequential(
            *[ResBlock(ngf) for _ in range(n_blocks)],
        )
        self.tail_large = nn.Sequential(
            nn.Conv2d(ngf, ngf // 2, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(ngf // 2, 3, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x, maps):
        # small scale
        h = torch.cat([x, maps['relu3_1']], 1)
        h = self.head_small(h)
        h = self.body_small(h) + x
        x = self.tail_small(h)

        # medium scale
        h = torch.cat([x, maps['relu2_1']], 1)
        h = self.head_medium(h)
        h = self.body_medium(h) + x
        x = self.tail_medium(h)

        # large scale
        h = torch.cat([x, maps['relu1_1']], 1)
        h = self.head_large(h)
        h = self.body_large(h) + x
        x = self.tail_large(h)

        return x

目的関数

Reconstruction loss

ここではピクセルごとのL1距離をReconstruction loss  L_{rec} とします. 超解像ではMSEよりL1距離を用いるほうがシャープな結果が得られるとされていてよく用いられます. $$ \mathcal{L}_{rec} = || I^{HR} - I^{SR}||_1 $$

Perceptual loss

ここではJohnson et al. [*14] にならい,VGG特徴のノルムを計算します. 論文ではノルムにフロベニウスノルムが使用されています. $$ \mathcal{L}_{per} = || \phi_i(I^{HR}) - \phi_i(I^{SR})||_F $$

Adversarial loss

WGAN-GP [*15]と同じ損失関数を定義します. 識別機の学習時にはここにGradient penaltyが加算されて学習の安定化を図ります. $$ \quad\quad\quad\quad\quad \mathcal{L}_{adv} = -\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})], \\ \min_{G} \max_{D \in G} \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] -\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})], $$

Texture loss

ここではENet [*16] にならい,VGG特徴とSwapped feature  M のGram行列をマルチスケールで比較します.  l はスケールを表しています. $$ \mathcal{L}_{t ex} = \sum_l \lambda_{l} || Gr( \phi_{l} (I^{SR}) \cdot S^\ast_l )- Gr(M_l \cdot S^\ast_l)||_F $$  M の計算時に最大値を取っている関係上,似たものが写っていない状況では類似度の低い特徴を引いてきている可能性があります. 類似度の高い特徴を優先的に利用するため,  j^\ast における類似度マップ  S^\ast を用いて重み付けを行なった上で比較を行います.

Overall objective

上記の目的関数を線形結合して目的関数  \mathcal{L} を定義します.  \lambda_{per} \lambda_{adv} \lambda_{te x}は各関数の重みを表していて,論文ではそれぞれ  1\mathrm{e}{-4} 1\mathrm{e}{-6} 1\mathrm{e}{-4}に設定されます. $$ \mathcal{L} = \mathcal{L}_{rec} + \lambda_{per} \mathcal{L}_{per} + \lambda_{adv} \mathcal{L}_{adv} + \lambda_{te x} \mathcal{L}_{te x} $$

実験

再掲になりますが,著者によるTensorFlow実装を参考に,SRNTTをPyTorchによって再現実装を行いました.

github.com

以下の実験結果において,SRNTTとして提示する画像は全て再現実装の出力となります.

条件

基本的に論文と同じ条件で学習を行います.

  • データセット
    • 学習データ: CUFED
      • train:val=9:1 にランダムスプリット.
    • テストデータ: CUFED5
      • 入力画像と類似度順に5枚の画像がセットになっているデータセット
  • 倍率:  \times 4
  • 学習回数
    • 事前学習:   \mathcal{L}_{rec} のみを使って5エポック
      • この際,Content extractorの初期値にはMMSRが提供するMSRGANの重みを使用.
    • メインの学習:  \mathcal{L} で100エポック
  • バッチサイズ: 9
  • 学習率
    • 初期値:  1\mathrm{e}{-4}
    • スケジューラ: 50エポックで  \times 0.1
  • オプティマイザ: Adam

結果

まずテストセットにおけるPSNRについて,論文が示す結果は次のようになります. HRは正解画像をランダムに変形したものを参照画像とした結果であり,その他は類似画像に対する結果で類似度順にL1からL5まで並んでいます.

f:id:S_aiueo321:20191113152716p:plain
論文におけるSRNTTの定量評価結果

これに対し,手元で学習したモデルのPSNRは次のようになりました. おおよそ再現できたのかなという感じです.

Ref画像 HR L1 L2 L3 L4 L5
PSNR[dB] 34.54 26.60 26.08 25.94 25.81 25.75

続いて,出力画像を見ていきます. SRNTT(HR)はHR画像自身を参照画像とした結果,SRNTT(L3)はRef Image(L3)を参照画像とした結果となります. この結果では,入力画像と参照画像に共通してNHL*17のロゴが写っており,SRNTT(L3)ではその部分を転写できていることが確認できます.

f:id:S_aiueo321:20191113162033p:plain
出力画像の比較 (CUFED5 `002`)

さらによく見ると,HR画像のロゴが若干ブレているため,ブレのないL3を参照したほうが結果がシャープになっていることがわかります. このように,質の高い参照画像が手に入る環境では,原画像を超える復元が可能となる場合があります.

次の結果では某夢の国っぽいお城の構造が転写できることが確認できます.

f:id:S_aiueo321:20191116083633p:plain
出力画像の比較 (CUFED5 `065`)

最後に少し面白い結果ですが,下段は上段を参照画像とした際の出力です. 祖母が写っているL3では口元を正しく転写していると思われますが,写っていないL2では別人の特徴を転写してきています. このことから,ある程度コンテンツの意味を認識して転写が行われるということがわかります*18

f:id:S_aiueo321:20191116084230p:plain
参照画像による画像の変化

まとめ

本エントリでは,最新のRefSR手法であるSRNTTについて解説し,PyTorchによって再現実装を行いました. 実験結果から,参照画像から高解像度の特徴を転写してかなり綺麗な画像復元が可能となることを確認しました.

ただし結果を見るとわかるように,転写ができていない部分が辛いという問題があり,複数画像を参照するなど,研究を進める過程で改善していきたいと考えています. 本エントリも修正があれば随時更新していく所存です.

*1:Zhang, Zhifei, et al. "Image Super-Resolution by Neural Texture Transfer." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.

*2:Zhang, Zhifei, et al. "Reference-conditioned super-resolution by neural texture transfer." arXiv preprint arXiv:1804.03360 (2018).

*3:Freeman, William T., Egon C. Pasztor, and Owen T. Carmichael. "Learning low-level vision." International journal of computer vision 40.1 (2000): 25-47.

*4:Timofte, Radu, Vincent De Smet, and Luc Van Gool. "A+: Adjusted anchored neighborhood regression for fast super-resolution." Asian conference on computer vision. Springer, Cham, 2014.

*5:このような手法群をExample-based Super-Resolutionと言ったりします.

*6:Yue, Huanjing, et al. "Landmark image super-resolution by retrieving web images." IEEE Transactions on Image Processing 22.12 (2013): 4865-4878.

*7:Lowe, David G. "Distinctive image features from scale-invariant keypoints." International journal of computer vision 60.2 (2004): 91-110.

*8:Zheng, Haitian, et al. "Crossnet: An end-to-end reference-based super resolution network using cross-scale warping." Proceedings of the European Conference on Computer Vision (ECCV). 2018.

*9:なので,実質的には超解像の結果を参照画像を使って改善する手法とも言えます.

*10:SISRに比べて8-10[dB]ほど改善します

*11:拡大・縮小は基本的にBicubic補間で行いますが,どんな手法を使っても構いません(精度に対して影響はないらしい).

*12:Ledig, Christian, et al. "Photo-realistic single image super-resolution using a generative adversarial network." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.

*13: upscaleは学習済みモデルの読み込みとデバッグのために用意してあります.

*14:Johnson, Justin, Alexandre Alahi, and Li Fei-Fei. "Perceptual losses for real-time style transfer and super-resolution." European conference on computer vision. Springer, Cham, 2016.

*15:Gulrajani, Ishaan, et al. "Improved training of wasserstein gans." Advances in neural information processing systems. 2017.

*16:Sajjadi, Mehdi SM, Bernhard Scholkopf, and Michael Hirsch. "Enhancenet: Single image super-resolution through automated texture synthesis." Proceedings of the IEEE International Conference on Computer Vision. 2017.

*17:余談ですが,4大スポーツで僕が唯一見ないスポーツです.

*18:別人貼り付けが良いか悪いかは置いておいて