@@ -18,25 +18,40 @@ def softmax(self, x):
1818 return e_x / e_x .sum ()
1919
2020 def group_words (self , words ):
21- groups = {}
22- for word in words :
23- parts = word .replace ('+' , '' )
24- key = parts
25- group = groups .setdefault (key , [])
26- group .append (word )
27-
21+ if not words :
22+ return []
23+
2824 result = []
29- for group in groups . values ():
30- has_special_word = any ( word .replace ('+' , '' ) in self . special_words for word in group )
31- if has_special_word and len ( group ) > 3 :
32- subgroups = [ group [ i : i + 3 ] for i in range ( 0 , len ( group ), 3 )]
33- result . extend ( subgroups )
34- elif len ( group ) > 3 and len ( group ) % 2 == 0 :
35- subgroups = [ group [ i : i + 2 ] for i in range ( 0 , len ( group ), 2 )]
36- result . extend ( subgroups )
25+ current_group = [ words [ 0 ]]
26+ current_base = words [ 0 ] .replace ('+' , '' )
27+
28+ for word in words [ 1 :]:
29+ base_word = word . replace ( '+' , '' )
30+
31+ if base_word == current_base :
32+ current_group . append ( word )
3733 else :
38- result .append (group )
39-
34+ if current_base in self .special_words and len (current_group ) > 3 :
35+ subgroups = [current_group [i :i + 3 ] for i in range (0 , len (current_group ), 3 )]
36+ result .extend (subgroups )
37+ elif len (current_group ) > 3 and len (current_group ) % 2 == 0 :
38+ subgroups = [current_group [i :i + 2 ] for i in range (0 , len (current_group ), 2 )]
39+ result .extend (subgroups )
40+ else :
41+ result .append (current_group )
42+
43+ current_group = [word ]
44+ current_base = base_word
45+
46+ if current_base in self .special_words and len (current_group ) > 3 :
47+ subgroups = [current_group [i :i + 3 ] for i in range (0 , len (current_group ), 3 )]
48+ result .extend (subgroups )
49+ elif len (current_group ) > 3 and len (current_group ) % 2 == 0 :
50+ subgroups = [current_group [i :i + 2 ] for i in range (0 , len (current_group ), 2 )]
51+ result .extend (subgroups )
52+ else :
53+ result .append (current_group )
54+
4055 return result
4156
4257 def transfer_grouping (self , grouped_list , target_list ):
@@ -57,6 +72,8 @@ def classify(self, texts, hypotheses, num_hypotheses):
5772 #print("NO_BATCH")
5873 outs = []
5974 grouped_h = self .group_words (hypotheses )
75+ #print(grouped_h)
76+ #print(hypotheses)
6077 grouped_t = self .transfer_grouping (grouped_h , preprocessed_texts )
6178 for h , t in zip (grouped_h , grouped_t ):
6279 probs = []
0 commit comments