
[WordSeg] Add example #526

Merged · 3 commits · May 14, 2020
1 change: 1 addition & 0 deletions Examples/CMakeLists.txt
@@ -8,3 +8,4 @@ add_subdirectory(BERT-CoLA)
add_subdirectory(GPT2-WikiText2)
add_subdirectory(NeuMF-MovieLens)
add_subdirectory(GPT2-Inference)
add_subdirectory(WordSeg)
9 changes: 9 additions & 0 deletions Examples/WordSeg/CMakeLists.txt
@@ -0,0 +1,9 @@
add_executable(WordSeg
  main.swift)
target_link_libraries(WordSeg PRIVATE
  TextModels
  Datasets)


install(TARGETS WordSeg
  DESTINATION bin)
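
For CMake users, the new target can also be built directly. A minimal sketch, assuming a Ninja generator and CMake ≥ 3.13 run from the repository root (toolchain flags may differ on your platform):

```sh
# Configure the build tree, then build only the WordSeg executable.
cmake -B build -G Ninja -S .
cmake --build build --target WordSeg
```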
38 changes: 38 additions & 0 deletions Examples/WordSeg/README.md
@@ -0,0 +1,38 @@
# Word Segmentation

This example demonstrates how to train the [word segmentation model][paper]
on the dataset used in the paper.

A segmental neural language model (SNLM) is instantiated from the library of
standard models, a custom training loop is defined, and the average training
loss for each epoch is printed.
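
In essence, the training loop reduces to the sketch below. This is a condensed
excerpt of this example's `main.swift` (all names come from that file), not a
standalone program:

```swift
var model = SNLM(parameters: modelParameters)
let optimizer = Adam(for: model)

for epoch in 1...maxEpochs {
  Context.local.learningPhase = .training
  for sentence in dataset.training {
    // Differentiate the lattice score with respect to the model.
    let (loss, gradients) = valueWithGradient(at: model) { model -> Float in
      let lattice = model.buildLattice(sentence, maxLen: maxLength)
      let score = lattice[sentence.count].semiringScore
      return -score.logp + lambd * exp(score.logr - score.logp)
    }
    optimizer.update(&model, along: gradients)
  }
}
```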

## Setup

To begin, you'll need the [latest version of Swift for
TensorFlow][s4tf] installed. Make sure you've added the correct version of
`swift` to your `PATH`.

To train the model using the full datasets published in the paper, run:

```sh
cd swift-models
swift run -c release WordSeg
```

To train the model using the smaller, unrealistic sample dataset bundled with this example, run:

```sh
cd swift-models
swift run -c release WordSeg Examples/WordSeg/smalldata.txt
```

To train the model on your own data, provide paths to plain-text files (one space-delimited sentence per line, as in `smalldata.txt`):

```sh
cd swift-models
swift run -c release WordSeg path/to/training_data.txt [path/to/validation_data.txt [path/to/test_data.txt]]
```

[s4tf]: https://github.com/tensorflow/swift/blob/master/Installation.md
[paper]: https://www.aclweb.org/anthology/P19-1645.pdf
115 changes: 115 additions & 0 deletions Examples/WordSeg/main.swift
@@ -0,0 +1,115 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Datasets
import ModelSupport
import TensorFlow
import TextModels

// Model flags.
let ndim = 512  // Hidden unit size.

// Training flags.
let dropoutProb = 0.5  // Dropout rate.
let order = 5  // Power of length penalty.
let maxEpochs = 1000  // Maximum number of training epochs.
let lambd: Float = 0.00075  // Weight of length penalty.

// Lexicon flags.
let maxLength = 10  // Maximum length of a string.
let minFreq = 10  // Minimum frequency of a string.

// Load user-provided data files.
let dataset: WordSegDataset
switch CommandLine.arguments.count {
case 1:
  dataset = try WordSegDataset()
case 2:
  dataset = try WordSegDataset(training: CommandLine.arguments[1])
case 3:
  dataset = try WordSegDataset(
    training: CommandLine.arguments[1], validation: CommandLine.arguments[2])
case 4:
  dataset = try WordSegDataset(
    training: CommandLine.arguments[1], validation: CommandLine.arguments[2],
    testing: CommandLine.arguments[3])
default:
  usage()
}

let lexicon = Lexicon(
  from: dataset.training,
  alphabet: dataset.alphabet,
  maxLength: maxLength,
  minFreq: minFreq
)

let modelParameters = SNLM.Parameters(
  ndim: ndim,
  dropoutProb: dropoutProb,
  chrVocab: dataset.alphabet,
  strVocab: lexicon,
  order: order
)

var model = SNLM(parameters: modelParameters)

let optimizer = Adam(for: model)

print("Starting training...")

for epoch in 1...maxEpochs {
  Context.local.learningPhase = .training
  var trainingLossSum: Float = 0
  var trainingBatchCount = 0
  for sentence in dataset.training {
    let (loss, gradients) = valueWithGradient(at: model) { model -> Float in
      // Score every segmentation of the sentence in one lattice pass.
      let lattice = model.buildLattice(sentence, maxLen: maxLength)
      let score = lattice[sentence.count].semiringScore
      // Under the semiring, exp(logr - logp) is the expected segment count.
      let expectedLength = exp(score.logr - score.logp)
      // Negative log-likelihood plus a length penalty weighted by lambd.
      let loss = -score.logp + lambd * expectedLength
      return loss
    }

    trainingLossSum += loss
    trainingBatchCount += 1
    optimizer.update(&model, along: gradients)

    if hasNaN(gradients) {
      print("Warning: grad has NaN")
    }
    if hasNaN(model) {
      print("Warning: model has NaN")
    }
  }

  print(
    """
    [Epoch \(epoch)] \
    Loss: \(trainingLossSum / Float(trainingBatchCount))
    """
  )
}

// Returns true if any `Tensor<Float>` reachable from `t` contains a NaN.
func hasNaN<T: KeyPathIterable>(_ t: T) -> Bool {
  for kp in t.recursivelyAllKeyPaths(to: Tensor<Float>.self) {
    if t[keyPath: kp].isNaN.any() { return true }
  }
  return false
}

// Prints usage and exits; called when the argument count is unsupported.
func usage() -> Never {
  print(
    "\(CommandLine.arguments[0]) path/to/training_data.txt [path/to/validation_data.txt [path/to/test_data.txt]]"
  )
  exit(1)
}
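
For reference, the four cases in the argument switch above correspond to invocations like these (file names are placeholders):

```sh
swift run -c release WordSeg                              # bundled reference dataset
swift run -c release WordSeg train.txt                    # custom training set
swift run -c release WordSeg train.txt valid.txt          # plus validation set
swift run -c release WordSeg train.txt valid.txt test.txt # plus test set
```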
12 changes: 12 additions & 0 deletions Examples/WordSeg/smalldata.txt
@@ -0,0 +1,12 @@
hello world
hello world
hello world
hello world
hello world
hello world
hello world
hello world
hello world
hello world
hello world
goodbye world
8 changes: 7 additions & 1 deletion Package.swift
@@ -35,6 +35,7 @@ let package = Package(
.executable(name: "GPT2-WikiText2", targets: ["GPT2-WikiText2"]),
.executable(name: "NeuMF-MovieLens", targets: ["NeuMF-MovieLens"]),
.executable(name: "CycleGAN", targets: ["CycleGAN"]),
.executable(name: "WordSeg", targets: ["WordSeg"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-protobuf.git", from: "1.7.0"),
@@ -135,6 +136,11 @@ let package = Package(
name: "pix2pix",
dependencies: ["Batcher", .product(name: "ArgumentParser", package: "swift-argument-parser"), "ModelSupport", "Datasets"],
path: "pix2pix"
)
),
.target(
name: "WordSeg",
dependencies: ["ModelSupport", "TextModels", "Datasets"],
path: "Examples/WordSeg"
)
]
)