
Commit 8d9c237

clean up label handling
1 parent 15f1e80 commit 8d9c237

File tree

1 file changed (+29, -21 lines)


utils.py

Lines changed: 29 additions & 21 deletions
@@ -24,16 +24,22 @@
                            'valid': "https://s3.amazonaws.com/datasets.huggingface.co/simplebooks-2-raw/valid.txt"},
     'simplebooks-92-raw': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/simplebooks-92-raw/train.txt",
                            'valid': "https://s3.amazonaws.com/datasets.huggingface.co/simplebooks-92-raw/valid.txt"},
-    'imdb': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.txt",
-             'valid': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/valid.txt",
-             'labels': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.labels.txt",
-                        'valid': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/valid.labels.txt",
-                        'convert': {'pos': 0, 'neg': 1}}},
-    'trec': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.txt",
-             'valid': "https://s3.amazonaws.com/datasets.huggingface.co/trec/valid.txt",
-             'labels': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.labels.txt",
-                        'valid': "https://s3.amazonaws.com/datasets.huggingface.co/trec/valid.labels.txt",
-                        'convert': {'NUM': 0, 'LOC': 1, 'HUM': 2, 'DESC': 3, 'ENTY': 4, 'ABBR': 5}}},
+    'imdb': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.txt",
+             'test': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/test.txt"},
+    'trec': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.txt",
+             'test': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.txt"},
+}
+
+DATASETS_LABELS_URL = {
+    'imdb': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.labels.txt",
+             'test': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/test.labels.txt"},
+    'trec': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.labels.txt",
+             'test': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.labels.txt"},
+}
+
+DATASETS_LABELS_CONVERSION = {
+    'imdb': {'pos': 0, 'neg': 1},
+    'trec': {'NUM': 0, 'LOC': 1, 'HUM': 2, 'DESC': 3, 'ENTY': 4, 'ABBR': 5},
 }
 
 PRETRAINED_MODEL_URL = "https://s3.amazonaws.com/models.huggingface.co/naacl-2019-tutorial/"
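
The hunk above replaces the nested 'labels'/'convert' entries with three flat module-level tables, all keyed by dataset name. As a quick illustration (not part of the commit), a labelled dataset such as 'imdb' or 'trec' now has an entry in each table:

# Illustrative lookup only; the constants are the ones defined in the hunk above.
for name in ('imdb', 'trec'):
    text_urls = DATASETS_URL[name]                  # {'train': ..., 'test': ...}
    label_urls = DATASETS_LABELS_URL[name]          # {'train': ..., 'test': ...}
    conversion = DATASETS_LABELS_CONVERSION[name]   # e.g. {'pos': 0, 'neg': 1} for 'imdb'
    assert set(text_urls) == set(label_urls)        # both provide 'train' and 'test' splits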
@@ -71,9 +77,9 @@ def add_logging_and_checkpoint_saving(trainer, evaluator, metrics, model, optimi
     # Add tensorboard logging with training and evaluation metrics
     tb_logger = TensorboardLogger(log_dir=None)
     tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=[prefix + "loss"]),
-                        event_name=Events.ITERATION_COMPLETED)
+                     event_name=Events.ITERATION_COMPLETED)
     tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer),
-                        event_name=Events.ITERATION_STARTED)
+                     event_name=Events.ITERATION_STARTED)
     @evaluator.on(Events.COMPLETED)
     def tb_log_metrics(engine):
         for name in metrics.keys():
@@ -97,29 +103,31 @@ def get_and_tokenize_dataset(tokenizer, dataset_dir='wikitext-103', dataset_cach
     else:
         # If the dataset is in our list of DATASETS_URL, use this url, otherwise, look for 'train.txt' and 'valid.txt' files
         if dataset_dir in DATASETS_URL:
-            dataset_dir = DATASETS_URL[dataset_dir]
+            dataset_map = DATASETS_URL[dataset_dir]
         else:
-            dataset_dir = {'train': os.path.join(dataset_dir, 'train.txt'),
+            dataset_map = {'train': os.path.join(dataset_dir, 'train.txt'),
                            'valid': os.path.join(dataset_dir, 'valid.txt')}
 
         logger.info("Get dataset from %s", dataset_dir)
         # Download and read dataset and replace a few token for compatibility with the Bert tokenizer we are using
         dataset = {}
-        for split_name in ['train', 'valid']:
-            dataset_file = cached_path(dataset_dir[split_name])
+        for split_name in dataset_map.keys():
+            dataset_file = cached_path(dataset_map[split_name])
             with open(dataset_file, "r", encoding="utf-8") as f:
                 all_lines = f.readlines()
             dataset[split_name] = [
-                line.strip(' ').replace('\n', '[SEP]' if not with_labels else '').replace('<unk>', '[UNK]') for line in tqdm(all_lines)]
+                line.strip(' ').replace('<unk>', '[UNK]').replace('\n', '[SEP]' if not with_labels else '')
+                for line in tqdm(all_lines)]
 
-        # Download and read labels if needed, convert labels names to integers
+        # If we have labels, download and and convert labels in integers
         labels = {}
         if with_labels:
-            for split_name in ['train', 'valid']:
-                dataset_file = cached_path(dataset_dir['labels'][split_name])
+            label_conversion_map = DATASETS_LABELS_CONVERSION[dataset_dir]
+            for split_name in DATASETS_LABELS_URL[dataset_dir]:
+                dataset_file = cached_path(dataset_map['labels'][split_name])
                 with open(dataset_file, "r", encoding="utf-8") as f:
                     all_lines = f.readlines()
-            labels[split_name] = [dataset_dir['labels']['convert'][line.strip()] for line in tqdm(all_lines)]
+                labels[split_name] = [label_conversion_map[line.strip()] for line in tqdm(all_lines)]
 
         # Tokenize and encode the dataset
         logger.info("Tokenize and encode the dataset")
