UNO Less-to-More Generalization: 通过上下文生成解锁更多可控性
UNO Less-to-More Generalization: 通过上下文生成解锁更多可控性
Less-to-More Generalization: Unlocking More Controllability by In-Context Generation
flyfish
AI 绘画工具只要输入文字就能画出图片,不过呢,有时候光有一张随便生成的图片可满足不了我们。我们还想更精准地 “指挥” 它画画,比如让某个我们指定的人或者东西出现在画里,甚至能让好几个特定的人或者东西一起出现在画里,而且还得让它们都保留着自己原本的样子和特点,UNO就像解决这个问题。
import os
import dataclasses
from typing import Literal
from accelerate import Accelerator
from transformers import HfArgumentParser
from PIL import Image
import json
import itertools# 图像拼接工具类
class ImageConcatenator:@staticmethoddef horizontal_concat(images):widths, heights = zip(*(img.size for img in images))total_width = sum(widths)max_height = max(heights)new_im = Image.new('RGB', (total_width, max_height))x_offset = 0for img in images:new_im.paste(img, (x_offset, 0))x_offset += img.size[0]return new_im# 推理参数类
@dataclasses.dataclass
class InferenceArgs:prompt: str | None = Noneimage_paths: list[str] | None = Noneeval_json_path: str | None = Noneoffload: bool = Falsenum_images_per_prompt: int = 1model_type: Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] = "flux-dev"width: int = 512height: int = 512ref_size: int = -1num_steps: int = 25guidance: float = 4seed: int = 3407save_path: str = "output/inference"only_lora: bool = Trueconcat_refs: bool = Falselora_rank: int = 512data_resolution: int = 512pe: Literal['d', 'h', 'w', 'o'] = 'd'# 数据加载器类
class DataLoader:def __init__(self, args):self.args = argsdef load(self):assert self.args.prompt is not None or self.args.eval_json_path is not None, \"Please provide either prompt or eval_json_path"if self.args.eval_json_path is not None:with open(self.args.eval_json_path, "rt") as f:data_dicts = json.load(f)data_root = os.path.dirname(self.args.eval_json_path)else:data_root = "./"data_dicts = [{"prompt": self.args.prompt, "image_paths": self.args.image_paths}]return data_dicts, data_root# 参考图像预处理类
class RefImagePreprocessor:def __init__(self, args):self.args = argsdef preprocess(self, data_dict, data_root):ref_imgs = [Image.open(os.path.join(data_root, img_path))for img_path in data_dict["image_paths"]]if self.args.ref_size == -1:self.args.ref_size = 512 if len(ref_imgs) == 1 else 320from uno.flux.pipeline import preprocess_refref_imgs = [preprocess_ref(img, self.args.ref_size) for img in ref_imgs]return ref_imgs# 图像生成器类
class ImageGenerator:def __init__(self, args, accelerator):self.args = argsself.accelerator = acceleratorfrom uno.flux.pipeline import UNOPipelineself.pipeline = UNOPipeline(self.args.model_type,self.accelerator.device,self.args.offload,only_lora=self.args.only_lora,lora_rank=self.args.lora_rank)def generate(self, data_dict, ref_imgs, j):image_gen = self.pipeline(prompt=data_dict["prompt"],width=self.args.width,height=self.args.height,guidance=self.args.guidance,num_steps=self.args.num_steps,seed=self.args.seed + j,ref_imgs=ref_imgs,pe=self.args.pe,)if self.args.concat_refs:image_gen = ImageConcatenator.horizontal_concat([image_gen, *ref_imgs])return image_gen# 结果保存器类
class ResultSaver:def __init__(self, args):self.args = argsdef save(self, image_gen, i, j, data_dict):os.makedirs(self.args.save_path, exist_ok=True)image_gen.save(os.path.join(self.args.save_path, f"{i}_{j}.png"))args_dict = vars(self.args)args_dict['prompt'] = data_dict["prompt"]args_dict['image_paths'] = data_dict["image_paths"]with open(os.path.join(self.args.save_path, f"{i}_{j}.json"), 'w') as f:json.dump(args_dict, f, indent=4)# 推理引擎类
class InferenceEngine:def __init__(self, args):self.args = argsself.accelerator = Accelerator()self.data_loader = DataLoader(args)self.ref_preprocessor = RefImagePreprocessor(args)self.image_generator = ImageGenerator(args, self.accelerator)self.result_saver = ResultSaver(args)def run(self):data_dicts, data_root = self.data_loader.load()for (i, data_dict), j in itertools.product(enumerate(data_dicts), range(self.args.num_images_per_prompt)):if (i * self.args.num_images_per_prompt + j) % self.accelerator.num_processes != self.accelerator.process_index:continueref_imgs = self.ref_preprocessor.preprocess(data_dict, data_root)image_gen = self.image_generator.generate(data_dict, ref_imgs, j)self.result_saver.save(image_gen, i, j, data_dict)def main():parser = HfArgumentParser([InferenceArgs])args = parser.parse_args_into_dataclasses()[0]engine = InferenceEngine(args)engine.run()if __name__ == "__main__":main()
使用示例
先告诉程序模型在哪?
export CLIP=/home/hub/models/AI-ModelScope/clip-vit-large-patch14/
export LORA=/home/hub/models/bytedance-research/UNO/dit_lora.safetensors
export T5=/home/hub/models/xlabs-ai/xflux_text_encoders/
export AE=/home/hub/models/black-forest-labs/FLUX___1-dev/ae.safetensors
export FLUX_DEV=/home/hub/models/black-forest-labs/FLUX___1-dev/flux1-dev.safetensors
上面的代码保存为your_script_name.py
1. 直接提供提示词和参考图像路径
若要直接提供提示词和参考图像路径,可以使用以下命令:
python your_script_name.py --prompt "A beautiful landscape" --image_paths "path/to/image1.jpg" "path/to/image2.jpg"
2. 使用 JSON 文件进行推理
若有一个包含推理数据的 JSON 文件,可以使用以下命令:
python your_script_name.py --eval_json_path "path/to/eval.json"
JSON 文件格式
JSON 文件应包含一个列表,列表中的每个元素为一个字典,字典包含 prompt
和 image_paths
字段,示例如下:
[{"prompt": "A cute cat","image_paths": ["path/to/cat1.jpg", "path/to/cat2.jpg"]},{"prompt": "A big dog","image_paths": ["path/to/dog1.jpg", "path/to/dog2.jpg"]}
]
- 必须提供
prompt
或eval_json_path
中的一个。 - 若
ref_size
为-1
,会根据参考图像数量自动设置大小。 - 生成的图像和对应的配置信息会保存到
save_path
指定的路径下。
参数说明
参数名 | 类型 | 默认值 | 说明 |
---|---|---|---|
prompt | str 或 None | None | 用于生成图像的提示词。如果不提供,需提供 eval_json_path 。 |
image_paths | list[str] 或 None | None | 参考图像的路径列表。 |
eval_json_path | str 或 None | None | 包含推理数据的 JSON 文件路径。如果不提供,需提供 prompt 。 |
offload | bool | False | 是否进行卸载操作。 |
num_images_per_prompt | int | 1 | 每个提示词生成的图像数量。 |
model_type | Literal["flux-dev", "flux-dev-fp8", "flux-schnell"] | "flux-dev" | 模型类型。 |
width | int | 512 | 生成图像的宽度。 |
height | int | 512 | 生成图像的高度。 |
ref_size | int | -1 | 参考图像的大小。若为 -1 ,则根据参考图像数量自动设置。 |
num_steps | int | 25 | 推理步数。 |
guidance | float | 4 | 引导系数。 |
seed | int | 3407 | 随机种子。 |
save_path | str | "output/inference" | 保存生成图像和配置文件的路径。 |
only_lora | bool | True | 是否仅使用 LoRA。 |
concat_refs | bool | False | 是否将参考图像与生成的图像拼接。 |
lora_rank | int | 512 | LoRA 的秩。 |
data_resolution | int | 512 | 数据分辨率。 |
pe | Literal['d', 'h', 'w', 'o'] | 'd' | 位置编码类型。 |
我们在一个统一的模型中整合了单主体和多主体生成。对于单主体场景,默认将参考图像的最长边设置为512;而对于多主体场景,则将其设置为320。得益于在多尺度数据集上的训练,UNO 在各种宽高比下展现了显著的灵活性。尽管其训练是在512的分辨率范围内进行的,但它能够处理更高的分辨率,包括512、568和704等。
使用的模型介绍
T5(Text-to-Text Transfer Transformer)
T5 是 Google 提出的,全称是 Text-to-Text Transfer Transformer,属于 Transformer 架构的一种。它的核心思想是把所有的 NLP 任务都转化为文本生成问题,比如翻译、问答、摘要等,都可以看作是输入文本到输出文本的转换。
简单说,它的核心本事是把 所有任务都变成“翻译”——不管你让它翻译语言、写文章摘要、回答问题,还是生成故事、做数学题,它都当成“输入一段文字,输出一段文字”来处理。
它能干啥?
- 翻译机:比如把中文翻译成英文,或者反过来,甚至小众语言也能翻(虽然不如专业翻译,但基础够用)。
- 摘要小能手:给一篇超长的文章,它能自动提炼出核心内容,比如“帮你快速看懂一篇新闻讲了啥”。
- 问答机器人:你问它问题,比如“怎么煮奶茶?”,它能像真人一样给你写步骤回答。
- 文章生成器:给个开头,它能接着写下去,比如帮你续写小说、生成邮件内容,甚至写广告文案。
- 逻辑小助手:简单的数学题、逻辑推理题,它也能通过文字推导给出答案(比如“小明比小红大3岁,小红5岁,小明几岁?”)。
为啥厉害?
- 万能统一:以前不同的文字任务(翻译、摘要、问答)需要不同的模型,T5只用一套“输入-输出”逻辑就全搞定了,就像一个工具能当锤子、螺丝刀、剪刀用。
- 学习能力强:它先“读”了互联网级别的海量文字(比如书籍、网页、文章),学会了人类语言的规律,然后针对具体任务(比如翻译)稍微“调整”一下,就能快速上手。
- 效果不错:在很多文字处理任务上,它的表现接近甚至超过人类,比如复杂句子的翻译、长文总结,准确率挺高。
例子
- 智能客服,背后可能有T5帮着理解你的问题并生成回复;
- 一些“自动写新闻”的AI,可能用T5来根据数据生成文字;
- 写论文时用的“自动摘要工具”,说不定就是T5改的。
CLIP 模型:让图片和文字能 “对话” 的跨模态高手
CLIP 就像一个 “图文翻译官”,让图片和文字能互相理解,不用大量定制化训练就能完成跨模态任务,既是科研人员的好帮手,也是未来创意 AI 的重要基石。虽然现在还有小缺点,但已经让我们看到了 “AI 能同时看懂图和字” 的无限可能。
想象一下,你给 AI 看一张猫的照片,同时输入 “一只戴蝴蝶结的猫”,它能立刻知道两者在说同一件事 —— 这就是 OpenAI 的 CLIP 模型在干的事。作为专门处理图像和文字关系的 “跨模态翻译官”,它的核心技能是把图片和文字都转化成同一种 “语言”(语义空间),让机器能看懂它们之间的相似性。
1 不用专门学就能认图:零样本分类
传统 AI 认图需要先喂大量标注好的图片(比如 10 万张猫的照片标上 “猫”),但 CLIP 像个 “触类旁通” 的学霸:
操作简单:你只要告诉它几个类别名称(比如 “猫、狗、汽车”),不用额外训练,它就能对着新图片判断属于哪个类别。
场景举例:比如你拍了张没见过的昆虫照片,直接输入几个候选名称,它能秒级匹配出最接近的类别,就像带了个 “万能图鉴”。
代码像搭积木:用现成的处理工具加载图片和文字,模型自动算相似度,新手也能快速上手(见后文示例)。
2 图文互搜:找图找文字都在行
CLIP 就像一个 “双向翻译机”:
按文找图:比如在搜索引擎输入 “夕阳下的沙滩排球”,它能从海量图片里精准捞出匹配的画面,比传统关键词搜索更智能(关键词可能漏词,但 CLIP 懂语义)。
按图找文:反过来,上传一张抽象画,它能匹配出最接近的文字描述,比如 “蓝色调抽象派画作,笔触粗犷”,适合艺术检索或盲文辅助。
应用场景:电商平台用它做 “拍照搜商品”,社交平台用它按文案推荐配图,都是 CLIP 的拿手好戏。
3 跨模态研究的 “瑞士军刀”
虽然普通用户感受不到,但 CLIP 在科研圈超火,因为它能帮科学家解决这些问题:
语言差异:看看模型在中文、日文里的理解准不准,有没有文化差异。
抗干扰能力:比如图片模糊、文字有错别字时,模型还能不能认出来。
公平性问题:检查模型会不会对某些群体(比如不同肤色、性别)有偏见,比如看到护士图就默认是女性,需要优化。
4 给生成模型当 “军师”
CLIP 自己不会画图片,但能当 “指挥官”:
指导画画:比如你让 AI 生成 “穿西装的企鹅在弹钢琴”,CLIP 会先判断生成的图片是否符合描述,不符合就让画画模型(比如 DALL-E、 Stable Diffusion)调整,相当于给生成模型装了个 “语义质检仪”。
艺术创作:艺术家可以用它筛选风格,比如输入 “梵高风格”,CLIP 会帮模型锁定类似的色彩和笔触,让生成更精准。
5 下游任务的 “多面手”
稍微调一调参数,CLIP 就能跨界打工:
写图片说明:给图片自动配文字,比如 “三只松鼠在树上吃坚果”。
回答视觉问题:比如对着猫的图片问 “这只猫是什么颜色?”,它能结合图像和问题给出答案。
跨模态检索:帮图书馆管理系统按文字找插图,或按历史图片找对应文献。
6 普通人也能玩的创意工具
CLIP 的灵活度让它成了 “脑洞孵化器”:
艺术搞怪:输入 “把蒙娜丽莎变成赛博朋克风格”,结合生成模型就能产出魔改作品。
内容审核:自动识别图片是否符合 “禁止暴力” 的文字规则,比关键词匹配更智能。
游戏开发:给游戏场景生成动态描述,比如玩家走到森林,AI 自动生成 “潮湿的泥土气息,远处传来鸟鸣” 的文案。
它的 “小毛病” 也得知道:
标签质量决定下限,如果给的类别名称很模糊(比如把 “金毛犬” 标成 “狗”),模型可能分不清具体品种。
中文能力弱一些,毕竟主要在英文数据里长大,处理中文时可能 “水土不服”,比如复杂成语或网络热词容易理解错。
数据偏见跟着走
直接用有风险,不能直接搬来用,不然容易出错。