-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathutterance.py
More file actions
115 lines (96 loc) · 4.45 KB
/
utterance.py
File metadata and controls
115 lines (96 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
""" Contains the Utterance class. """
import sql_util
import tokenizers
# Key used by the anonymizer for the cleaned natural-language input side.
ANON_INPUT_KEY = "cleaned_nl"
# Key into an example dict holding the gold SQL output sequences.
OUTPUT_KEY = "sql"
class Utterance:
    """ Utterance class.

    Wraps a single natural-language example paired with its gold SQL queries:
    the tokenized input sequence (optionally anonymized), all gold query
    alternatives with their execution results, and the single shortest gold
    query — anonymized and snippet-augmented as configured — used as the
    training target.
    """
    def process_input_seq(self,
                          anonymize,
                          anonymizer,
                          anon_tok_to_ent):
        """ Sets self.input_seq_to_use, anonymizing the input if requested.

        Inputs:
            anonymize (bool): Whether to anonymize the input sequence.
            anonymizer: Object providing anonymize(); required when anonymize
                is True.
            anon_tok_to_ent (dict): Anonymization token -> entity mapping;
                extended in place with newly introduced anonymization tokens.
        """
        # A non-empty mapping only makes sense when anonymizing, and
        # anonymizing requires an anonymizer.
        assert not anon_tok_to_ent or anonymize
        assert not anonymize or anonymizer

        if anonymize:
            assert anonymizer
            self.input_seq_to_use = anonymizer.anonymize(
                self.original_input_seq, anon_tok_to_ent, ANON_INPUT_KEY, add_new_anon_toks=True)
        else:
            self.input_seq_to_use = self.original_input_seq

    def process_gold_seq(self,
                         output_sequences,
                         nl_to_sql_dict,
                         available_snippets,
                         anonymize,
                         anonymizer,
                         anon_tok_to_ent):
        """ Selects and processes the gold query used as the training target.

        Inputs:
            output_sequences (list): (query tokens, execution results) pairs.
            nl_to_sql_dict: Maps input tokens to SQL entities via
                get_sql_entities().
            available_snippets (list): Snippets available to this utterance.
            anonymize (bool): Whether to anonymize the gold query.
            anonymizer: Object providing anonymize(); used when anonymize.
            anon_tok_to_ent (dict): Anonymization token -> entity mapping.

        Sets original_gold_query, gold_sql_results, contained_entities,
        all_gold_queries, anonymized_gold_query, and gold_query_to_use.
        """
        # Get entities in the input sequence:
        #  - anonymized entity types
        #  - other recognized entities (this includes "flight")
        entities_in_input = [
            [tok] for tok in self.input_seq_to_use if tok in anon_tok_to_ent]
        entities_in_input.extend(
            nl_to_sql_dict.get_sql_entities(
                self.input_seq_to_use))

        # Train on the shortest gold query among all provided alternatives.
        shortest_gold_and_results = min(output_sequences,
                                        key=lambda x: len(x[0]))

        self.original_gold_query = shortest_gold_and_results[0]
        self.gold_sql_results = shortest_gold_and_results[1]
        self.contained_entities = entities_in_input

        # Keep track of all gold queries and the resulting tables so that we
        # can give credit if it predicts a different correct sequence.
        self.all_gold_queries = output_sequences

        # Anonymize the gold query if necessary (without introducing new
        # anonymization tokens — those come only from the input side).
        self.anonymized_gold_query = self.original_gold_query
        if anonymize:
            self.anonymized_gold_query = anonymizer.anonymize(
                self.original_gold_query, anon_tok_to_ent, OUTPUT_KEY, add_new_anon_toks=False)

        # Add snippets to it.
        self.gold_query_to_use = sql_util.add_snippets_to_query(
            available_snippets, entities_in_input, self.anonymized_gold_query)

    def __init__(self,
                 example,
                 available_snippets,
                 nl_to_sql_dict,
                 params,
                 anon_tok_to_ent=None,
                 anonymizer=None):
        """ Builds an Utterance from its dictionary representation.

        Inputs:
            example (dict): Holds the input sequence (under params.input_key)
                and gold output sequences (under OUTPUT_KEY).
            available_snippets (list): Snippets available to this utterance.
            nl_to_sql_dict: Maps input tokens to SQL entities.
            params: Run parameters; uses params.input_key and params.anonymize.
            anon_tok_to_ent (dict, optional): Anonymization token -> entity
                mapping, shared across an interaction; a fresh dict if omitted.
            anonymizer (optional): Anonymizer; required if params.anonymize.
        """
        # BUGFIX: the default used to be a shared mutable dict ({}), which
        # anonymize(add_new_anon_toks=True) mutates in place — every Utterance
        # constructed without an explicit mapping shared one polluted dict.
        if anon_tok_to_ent is None:
            anon_tok_to_ent = {}

        # Get output and input sequences from the dictionary representation.
        output_sequences = example[OUTPUT_KEY]
        self.original_input_seq = tokenizers.nl_tokenize(example[params.input_key])
        self.available_snippets = available_snippets

        # Prune trivially short gold queries (three tokens or fewer).
        output_sequences = [
            sequence for sequence in output_sequences if len(sequence[0]) > 3]

        # Only keep this example if there is at least one output sequence
        # and a non-empty input sequence.
        self.keep = bool(output_sequences) and bool(self.original_input_seq)
        if not self.keep:
            return

        # Process the input sequence
        self.process_input_seq(params.anonymize,
                               anonymizer,
                               anon_tok_to_ent)

        # Process the gold sequence
        self.process_gold_seq(output_sequences,
                              nl_to_sql_dict,
                              self.available_snippets,
                              params.anonymize,
                              anonymizer,
                              anon_tok_to_ent)

    def __str__(self):
        """ Returns a human-readable summary of the utterance. """
        lines = [
            "Original input: " + " ".join(self.original_input_seq),
            "Modified input: " + " ".join(self.input_seq_to_use),
            "Original output: " + " ".join(self.original_gold_query),
            "Modified output: " + " ".join(self.gold_query_to_use),
            "Snippets:",
        ]
        lines.extend(str(snippet) for snippet in self.available_snippets)
        return "\n".join(lines) + "\n"

    def length_valid(self, input_limit, output_limit):
        """ Returns True iff both the input and gold sequences used for
        training are strictly shorter than their respective limits. """
        return (len(self.input_seq_to_use) < input_limit
                and len(self.gold_query_to_use) < output_limit)