                          'valid': "https://s3.amazonaws.com/datasets.huggingface.co/simplebooks-2-raw/valid.txt"},
    'simplebooks-92-raw': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/simplebooks-92-raw/train.txt",
                           'valid': "https://s3.amazonaws.com/datasets.huggingface.co/simplebooks-92-raw/valid.txt"},
-   'imdb': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.txt",
-            'valid': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/valid.txt",
-            'labels': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.labels.txt",
-                       'valid': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/valid.labels.txt",
-                       'convert': {'pos': 0, 'neg': 1}}},
-   'trec': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.txt",
-            'valid': "https://s3.amazonaws.com/datasets.huggingface.co/trec/valid.txt",
-            'labels': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.labels.txt",
-                       'valid': "https://s3.amazonaws.com/datasets.huggingface.co/trec/valid.labels.txt",
-                       'convert': {'NUM': 0, 'LOC': 1, 'HUM': 2, 'DESC': 3, 'ENTY': 4, 'ABBR': 5}}},
+   'imdb': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.txt",
+            'test': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/test.txt"},
+   'trec': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.txt",
+            'test': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.txt"},
+   }
+
+ DATASETS_LABELS_URL = {
+   'imdb': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/train.labels.txt",
+            'test': "https://s3.amazonaws.com/datasets.huggingface.co/aclImdb/test.labels.txt"},
+   'trec': {'train': "https://s3.amazonaws.com/datasets.huggingface.co/trec/train.labels.txt",
+            'test': "https://s3.amazonaws.com/datasets.huggingface.co/trec/test.labels.txt"},
+   }
+
+ DATASETS_LABELS_CONVERSION = {
+   'imdb': {'pos': 0, 'neg': 1},
+   'trec': {'NUM': 0, 'LOC': 1, 'HUM': 2, 'DESC': 3, 'ENTY': 4, 'ABBR': 5},
    }

PRETRAINED_MODEL_URL = "https://s3.amazonaws.com/models.huggingface.co/naacl-2019-tutorial/"
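The three tables above are keyed by the same dataset names, so a labelled corpus resolves to its text URLs, its label URLs, and its label-to-integer mapping with one lookup per table. A minimal sketch, assuming the dictionaries defined above are in scope (the local variable names below are only for illustration):

name = 'imdb'
text_urls = DATASETS_URL[name]                   # {'train': ..., 'test': ...}
label_urls = DATASETS_LABELS_URL[name]           # {'train': ..., 'test': ...}
label_to_id = DATASETS_LABELS_CONVERSION[name]   # {'pos': 0, 'neg': 1}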
@@ -71,9 +77,9 @@ def add_logging_and_checkpoint_saving(trainer, evaluator, metrics, model, optimi
    # Add tensorboard logging with training and evaluation metrics
    tb_logger = TensorboardLogger(log_dir=None)
    tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=[prefix + "loss"]),
-                    event_name=Events.ITERATION_COMPLETED)
+                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer),
-                    event_name=Events.ITERATION_STARTED)
+                     event_name=Events.ITERATION_STARTED)
    @evaluator.on(Events.COMPLETED)
    def tb_log_metrics(engine):
        for name in metrics.keys():
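The two attach calls above hook into pytorch-ignite's event system: the training loss is logged after each iteration, while optimizer parameters (e.g. the learning rate) are logged when an iteration starts. A stand-alone sketch of those two events, assuming pytorch-ignite is installed; the update function and data are dummies made up for the example:

from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: batch * 2)  # dummy update step returning an "output"

@trainer.on(Events.ITERATION_STARTED)
def log_before_step(engine):
    print("starting iteration", engine.state.iteration)

@trainer.on(Events.ITERATION_COMPLETED)
def log_after_step(engine):
    print("finished iteration", engine.state.iteration, "output:", engine.state.output)

trainer.run([1, 2, 3], max_epochs=1)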
@@ -97,29 +103,31 @@ def get_and_tokenize_dataset(tokenizer, dataset_dir='wikitext-103', dataset_cach
    else:
        # If the dataset is in our list of DATASETS_URL, use this url, otherwise, look for 'train.txt' and 'valid.txt' files
        if dataset_dir in DATASETS_URL:
-           dataset_dir = DATASETS_URL[dataset_dir]
+           dataset_map = DATASETS_URL[dataset_dir]
        else:
-           dataset_dir = {'train': os.path.join(dataset_dir, 'train.txt'),
+           dataset_map = {'train': os.path.join(dataset_dir, 'train.txt'),
                           'valid': os.path.join(dataset_dir, 'valid.txt')}

        logger.info("Get dataset from %s", dataset_dir)
        # Download and read dataset and replace a few tokens for compatibility with the Bert tokenizer we are using
        dataset = {}
-       for split_name in ['train', 'valid']:
-           dataset_file = cached_path(dataset_dir[split_name])
+       for split_name in dataset_map.keys():
+           dataset_file = cached_path(dataset_map[split_name])
            with open(dataset_file, "r", encoding="utf-8") as f:
                all_lines = f.readlines()
                dataset[split_name] = [
-                   line.strip(' ').replace('\n', '[SEP]' if not with_labels else '').replace('<unk>', '[UNK]') for line in tqdm(all_lines)]
+                   line.strip(' ').replace('<unk>', '[UNK]').replace('\n', '[SEP]' if not with_labels else '')
+                   for line in tqdm(all_lines)]

-       # Download and read labels if needed, convert labels names to integers
+       # If we have labels, download them and convert label names to integers
        labels = {}
        if with_labels:
-           for split_name in ['train', 'valid']:
-               dataset_file = cached_path(dataset_dir['labels'][split_name])
+           label_conversion_map = DATASETS_LABELS_CONVERSION[dataset_dir]
+           for split_name in DATASETS_LABELS_URL[dataset_dir]:
+               dataset_file = cached_path(DATASETS_LABELS_URL[dataset_dir][split_name])
                with open(dataset_file, "r", encoding="utf-8") as f:
                    all_lines = f.readlines()
-                   labels[split_name] = [dataset_dir['labels']['convert'][line.strip()] for line in tqdm(all_lines)]
+                   labels[split_name] = [label_conversion_map[line.strip()] for line in tqdm(all_lines)]

        # Tokenize and encode the dataset
        logger.info("Tokenize and encode the dataset")
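As a stand-alone illustration of the label path in the function above: each line of a labels file is stripped and mapped to an integer class id through the conversion table. The conversion dictionary is copied from the code; the input lines are made up for the example:

DATASETS_LABELS_CONVERSION = {
    'imdb': {'pos': 0, 'neg': 1},
    'trec': {'NUM': 0, 'LOC': 1, 'HUM': 2, 'DESC': 3, 'ENTY': 4, 'ABBR': 5},
}

label_conversion_map = DATASETS_LABELS_CONVERSION['trec']
all_lines = ["NUM\n", "LOC\n", "ABBR\n"]  # stand-in for the lines read from a *.labels.txt file
labels = [label_conversion_map[line.strip()] for line in all_lines]
assert labels == [0, 1, 5]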