diff --git a/edits/alignment/aligner.py b/edits/alignment/aligner.py index 580b225..c472dff 100644 --- a/edits/alignment/aligner.py +++ b/edits/alignment/aligner.py @@ -442,8 +442,8 @@ def reduce_inserts_deletions(alignment): assert len(example['src']) == len(example['tgt']) src, tgt = example['src'], example['tgt'] - src = [x.replace('PNX', '').replace('NIL','') for x in src] - tgt = [x.replace('PNX', '').replace('NIL','') for x in tgt] + src = [re.sub(r'\b(NIL|PNX)\b', '', x) for x in src] + tgt = [re.sub(r'\b(NIL|PNX)\b', '', x) for x in tgt] i = 0 s_idx = 0 diff --git a/edits/utils.py b/edits/utils.py index c039dd9..22af230 100644 --- a/edits/utils.py +++ b/edits/utils.py @@ -166,7 +166,14 @@ def insert_to_append(edits): # Special case for appends at the beginning of the sequence if processed_edits[0].startswith('A') and re.sub(r'A_\[.*?\]', '', processed_edits[0]) == 'K': - processed_edits[0] = processed_edits[0].replace('K', 'K' * len(subwords[0].replace('##', ''))) + # Calculate length of kept section + keep_section = 'K' * len(subwords[0].replace('##', '')) + # Split the edit by bracketed segments + edit_sections = re.split(r'(\[[^\]]*\])', processed_edits[0]) + # Replace 'K' only outside of the brackets + for i in range(0, len(edit_sections), 2): + edit_sections[i] = edit_sections[i].replace('K', keep_section) + processed_edits[0] = "".join(edit_sections) processed_edits = [SubwordEdit(subword, raw_subword, edit) for subword, raw_subword, edit in zip(subwords, raw_subwords, processed_edits)]