以图生图(图像到图像任务)¶
图像到图像任务是指应用程序接收一张图像并输出另一张图像的任务。这包括多种子任务,例如图像增强(超分辨率、低光增强、去雨等)、图像修复等。
本指南将向您展示如何:
- 使用图像到图像管道进行超分辨率任务,
- 不使用管道运行图像到图像模型进行相同的任务。
请注意,截至本指南发布时,image-to-image 管道仅支持超分辨率任务。
让我们从安装必要的库开始。
In [ ]:
pip install transformers
现在,我们可以使用 Swin2SR 模型 初始化管道。然后,我们可以通过用图像调用管道来进行推理。目前,此管道仅支持 Swin2SR 模型。
In [ ]:
from transformers import pipeline
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pipe = pipeline(task="image-to-image", model="caidas/swin2SR-lightweight-x2-64", device=device)
接下来,让我们加载一张图像。
In [ ]:
from PIL import Image
import requests
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/cat.jpg"
image = Image.open(requests.get(url, stream=True).raw)
print(image.size)
现在,我们可以使用管道进行推理。我们将得到猫图像的放大版本。
In [ ]:
upscaled = pipe(image)
print(upscaled.size)
如果您希望不使用管道自行进行推理,可以使用 Swin2SRForImageSuperResolution 和 Swin2SRImageProcessor 类。我们将使用相同的模型检查点进行此操作。让我们初始化模型和处理器。
In [ ]:
from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor
model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-lightweight-x2-64").to(device)
processor = Swin2SRImageProcessor("caidas/swin2SR-lightweight-x2-64")
pipeline 抽象化了我们需要自己做的预处理和后处理步骤,所以让我们预处理图像。我们将图像传递给处理器,然后将像素值移动到 GPU。
In [ ]:
pixel_values = processor(image, return_tensors="pt").pixel_values
print(pixel_values.shape)
pixel_values = pixel_values.to(device)
现在,我们可以通过将像素值传递给模型来进行图像推理。输出是一个 ImageSuperResolutionOutput 类型的对象,如下所示:
In [ ]:
import torch
with torch.no_grad():
outputs = model(pixel_values)
我们需要获取 reconstruction 并进行后处理以便可视化。让我们看看它是什么样子的。
In [ ]:
outputs.reconstruction.data.shape
# torch.Size([1, 3, 880, 1072])
我们需要压缩输出并去掉轴 0,裁剪值,然后将其转换为 numpy 浮点数。然后,我们将排列轴以获得形状 [1072, 880],最后,将输出范围恢复到 [0, 255]。
In [ ]:
import numpy as np
# 压缩,移至 CPU 并裁剪值
output = outputs.reconstruction.data.squeeze().cpu().clamp_(0, 1).numpy()
# 重新排列轴
output = np.moveaxis(output, source=0, destination=-1)
# 将值恢复到像素值范围
output = (output * 255.0).round().astype(np.uint8)
Image.fromarray(output)