@@ -10,11 +10,16 @@ int main(int argc, char ** argv) {
1010 return 1 ;
1111 }
1212
13- int n_labels = params.texts .size ();
13+ const size_t n_labels = params.texts .size ();
1414 if (n_labels < 2 ) {
1515 printf (" %s: You must specify at least 2 texts for zero-shot labeling\n " , __func__);
1616 }
1717
18+ const char * labels[n_labels];
19+ for (size_t i = 0 ; i < n_labels; ++i) {
20+ labels[i] = params.texts [i].c_str ();
21+ }
22+
1823 auto ctx = clip_model_load (params.model .c_str (), params.verbose );
1924 if (!ctx) {
2025 printf (" %s: Unable to load model from %s" , __func__, params.model .c_str ());
@@ -23,40 +28,21 @@ int main(int argc, char ** argv) {
2328
2429 // load the image
2530 const auto & img_path = params.image_paths [0 ].c_str ();
26- clip_image_u8 img0;
27- clip_image_f32 img_res;
28- if (!clip_image_load_from_file (img_path, &img0)) {
31+ clip_image_u8 input_img;
32+ if (!clip_image_load_from_file (img_path, &input_img)) {
2933 fprintf (stderr, " %s: failed to load image from '%s'\n " , __func__, img_path);
3034 return 1 ;
3135 }
3236
33- const int vec_dim = clip_get_vision_hparams (ctx)->projection_dim ;
34-
35- clip_image_preprocess (ctx, &img0, &img_res);
36-
37- float img_vec[vec_dim];
38- if (!clip_image_encode (ctx, params.n_threads , &img_res, img_vec, false )) {
37+ float sorted_scores[n_labels];
38+ int sorted_indices[n_labels];
39+ if (!clip_zero_shot_label_image (ctx, params.n_threads , &input_img, labels, n_labels, sorted_scores, sorted_indices)) {
40+ fprintf (stderr, " Unable to apply ZSL\n " );
3941 return 1 ;
4042 }
4143
42- // encode texts and compute similarities
43- float txt_vec[vec_dim];
44- float similarities[n_labels];
45-
46- for (int i = 0 ; i < n_labels; i++) {
47- const auto & text = params.texts [i].c_str ();
48- auto tokens = clip_tokenize (ctx, text);
49- clip_text_encode (ctx, params.n_threads , &tokens, txt_vec, false );
50- similarities[i] = clip_similarity_score (img_vec, txt_vec, vec_dim);
51- }
52-
53- // apply softmax and sort scores
54- float sorted_scores[n_labels];
55- int indices[n_labels];
56- softmax_with_sorting (similarities, n_labels, sorted_scores, indices);
57-
5844 for (int i = 0 ; i < n_labels; i++) {
59- auto label = params. texts [indices [i]]. c_str () ;
45+ auto label = labels[sorted_indices [i]];
6046 float score = sorted_scores[i];
6147 printf (" %s = %1.4f\n " , label, score);
6248 }
0 commit comments