|
26 | 26 |
|
27 | 27 | import omegaconf |
28 | 28 |
|
def get_domain_text_embs(model, cfg, source_text_prompts, target_text_prompts, class_names):
    """
    Gets the text embeddings of the prompts describing the source and target domains.

    If cfg.AUGMENTATION.GENERIC is True, source_text_prompts and target_text_prompts
    are plain strings; otherwise they are templates with a ``{}`` slot that each
    class name is substituted into.

    Args:
        model: CLIP-style model forwarded to ``zeroshot_classifier``.
        cfg: omegaconf config; reads AUGMENTATION.GENERIC, METHOD.NORMALIZE and
            EXP.IMAGE_FEATURES.
        source_text_prompts: prompts (or templates) describing the source domain;
            may be empty.
        target_text_prompts: prompts (or templates) describing the target domain(s).
            In the non-generic path a flat list of strings is treated as a single
            prompt group.
        class_names: class names substituted into templates (non-generic path only).

    Returns:
        Tuple ``(source_embeddings, target_embeddings)``. When no source prompts
        are given, ``source_embeddings`` is a zero tensor shaped like
        ``target_embeddings``.
    """
    if cfg.AUGMENTATION.GENERIC:
        text_embeddings = zeroshot_classifier(target_text_prompts, model, normalize=cfg.METHOD.NORMALIZE, model_type=cfg.EXP.IMAGE_FEATURES)
        text_embeddings = np.transpose(text_embeddings, (1, 0))
        if len(source_text_prompts) > 0:
            source_embeddings = zeroshot_classifier(source_text_prompts, model, normalize=cfg.METHOD.NORMALIZE, model_type=cfg.EXP.IMAGE_FEATURES)
            print("source emb before averaging", source_embeddings.shape)
            # Collapse multiple source prompts into one mean source embedding.
            source_embeddings = source_embeddings.mean(dim=0)
            print("source emb after averaging", source_embeddings.shape)
            # NOTE(review): the original also computed normalized
            # (target - source) diffs here, but never used or returned them;
            # that dead computation has been removed.
        else:
            # No source prompts: mirror the non-generic path and return zeros.
            source_embeddings = torch.zeros_like(text_embeddings)
        # BUG FIX: the original never bound target_embeddings on this path,
        # so the shared return below raised NameError for generic prompts.
        target_embeddings = text_embeddings
    else:
        print(target_text_prompts)
        # Build one (num_classes, emb_size) embedding block per prompt/template,
        # source prompts first so they can be split off below.
        all_texts = []
        for t in source_text_prompts:
            texts = [[t.format(c)] for c in class_names]
            text_emb = zeroshot_classifier(texts, model, normalize=cfg.METHOD.NORMALIZE, model_type=cfg.EXP.IMAGE_FEATURES).T
            print(texts, "text_emb", text_emb.shape)
            all_texts.append(text_emb)
        # A flat list of strings is a single prompt group; wrap it for the loop.
        if isinstance(target_text_prompts[0], str):
            target_text_prompts = [target_text_prompts]
        print(target_text_prompts)
        for p in target_text_prompts:
            print(p)
            texts = [[t.format(c) for t in p] for c in class_names]
            text_emb = zeroshot_classifier(texts, model, normalize=cfg.METHOD.NORMALIZE, model_type=cfg.EXP.IMAGE_FEATURES).T
            all_texts.append(text_emb)
        text_pairs = torch.stack(all_texts)
        print("text pairs", text_pairs.shape)
        # Split the stacked prompts back into source and target halves; with no
        # source prompts everything is target and the source is all zeros.
        if len(source_text_prompts) > 0:
            source_embeddings = text_pairs[:len(source_text_prompts)]
            target_embeddings = text_pairs[len(source_text_prompts):]
        else:
            target_embeddings = text_pairs
            source_embeddings = torch.zeros_like(target_embeddings)
    print("target embeddings", target_embeddings.shape)
    return source_embeddings, target_embeddings
89 | | - |
90 | 29 | def get_features(dataset, model, device, model_type): |
91 | 30 | if model_type != 'clip' and model_type != 'openclip': |
92 | 31 | return get_resnet_features(dataset, model, device) |
|
0 commit comments