-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathlaion_aesthetics.py
More file actions
80 lines (66 loc) · 2.82 KB
/
laion_aesthetics.py
File metadata and controls
80 lines (66 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import torch.nn as nn
from transformers import CLIPVisionModelWithProjection, CLIPProcessor
from huggingface_hub import hf_hub_download
import os
from typing import Union
from PIL import Image
# Absolute path of the directory that contains this module.
# NOTE(review): not referenced anywhere in the visible code — confirm it is
# used elsewhere in the file/package, otherwise it can be removed.
current_dir = os.path.dirname(os.path.abspath(__file__))
from base_verifier import BaseVerifier
class MLP(nn.Module):
    """Regression head that maps a 768-dim CLIP image embedding to one score.

    The network is a plain stack of ``nn.Linear`` layers with ``nn.Dropout``
    between some of them and no activation functions. Dropout only has an
    effect while the module is in training mode.

    The position of every sub-module is significant: a saved state dict
    addresses parameters as ``layers.<index>.weight`` / ``layers.<index>.bias``,
    so the layer order below must not change.
    """

    def __init__(self):
        super().__init__()
        # 768 -> 1024 -> 128 -> 64 -> 16 -> 1, with dropout after the first
        # three projections (indices 1, 3 and 5 in the Sequential).
        stack = [
            nn.Linear(768, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, embed):
        """Apply the layer stack; the last dimension of ``embed`` must be 768.

        Returns a tensor whose last dimension is 1 (one raw score per row).
        """
        raw_score = self.layers(embed)
        return raw_score
class LAIONAestheticVerifier(BaseVerifier):
    """Score images (or videos) with the LAION improved aesthetic predictor.

    Based on https://github.com/christophschuhmann/improved-aesthetic-predictor.

    Pipeline: ``openai/clip-vit-large-patch14`` vision tower with projection
    -> L2-normalised image embedding -> ``MLP`` regression head whose weights
    are downloaded from ``trl-lib/ddpo-aesthetic-predictor``.

    Keyword Args:
        dtype: torch dtype used for both the CLIP model and the MLP head
            (default ``torch.float32``). Any other keyword arguments are
            ignored.
    """

    SUPPORTED_METRIC_CHOICES = ["laion_aesthetic_score"]

    def __init__(self, **kwargs):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = kwargs.pop("dtype", torch.float32)

        self.clip = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").eval()
        self.clip.to(self.device, self.dtype)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

        self.mlp = MLP()
        path = hf_hub_download("trl-lib/ddpo-aesthetic-predictor", "aesthetic-model.pth")
        state_dict = torch.load(path, weights_only=True, map_location=torch.device("cpu"))
        self.mlp.load_state_dict(state_dict)
        # A freshly constructed nn.Module is in training mode, so without
        # eval() the MLP's Dropout layers would randomly zero activations at
        # scoring time and make scores non-deterministic. no_grad() /
        # inference_mode() do NOT disable dropout; only eval() does.
        self.mlp.to(self.device, self.dtype).eval()

    def prepare_inputs(self, images: Union[list[Image.Image], Image.Image], prompts=None, **kwargs):
        """Preprocess one image or a list of images for ``score``.

        Args:
            images: a single PIL image or a list of PIL images.
            prompts: unused by this verifier.
            **kwargs: unused.

        Returns:
            A dict of processor outputs (tensors) moved to ``self.device`` and
            cast to ``self.dtype``.
        """
        images = images if isinstance(images, list) else [images]
        inputs = self.processor(images=images, return_tensors="pt")
        # The processor is called with images only (presumably yielding just
        # ``pixel_values``), so casting every tensor to the float dtype is safe.
        return {k: v.to(device=self.device, dtype=self.dtype) for k, v in inputs.items()}

    def _score(self, inputs):
        """Return a 1-D tensor with one aesthetic score per image in ``inputs``."""
        # First model output; expected to be the projected image embeddings
        # (``image_embeds``) of shape (batch, dim).
        embed = self.clip(**inputs)[0]
        # Normalise each embedding to unit L2 norm before the MLP head.
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        # MLP ends in Linear(16, 1): (batch, 1) -> (batch,).
        return self.mlp(embed).squeeze(1)

    @torch.inference_mode()
    def score(self, inputs, **kwargs):
        """Compute LAION aesthetic scores.

        ``inference_mode`` already disables gradient tracking, so a separate
        ``no_grad`` is unnecessary.

        Args:
            inputs: either one dict of model inputs (as produced by
                ``prepare_inputs``) for a batch of images, or a list of such
                dicts, one per video (holding that video's frames).
            **kwargs: unused.

        Returns:
            A list of ``{"laion_aesthetic_score": float}`` dicts: one per image
            for a single batch, or one per video (mean over its frames) when
            ``inputs`` is a list.
        """
        # TODO: consider batching inputs if they get too large.
        if isinstance(inputs, list):
            # Videos: average the per-frame scores in float32 so that fp16/bf16
            # dtypes don't lose precision in the reduction.
            return [
                {"laion_aesthetic_score": self._score(video_inputs).float().mean().item()}
                for video_inputs in inputs
            ]
        # Images: a single tolist() transfers all scores at once instead of
        # one device sync per .item() call.
        return [{"laion_aesthetic_score": value} for value in self._score(inputs).tolist()]