diff --git a/Examples/CMakeLists.txt b/Examples/CMakeLists.txt
index 5f384f67038..b93a383189b 100644
--- a/Examples/CMakeLists.txt
+++ b/Examples/CMakeLists.txt
@@ -8,3 +8,4 @@ add_subdirectory(BERT-CoLA)
 add_subdirectory(GPT2-WikiText2)
 add_subdirectory(NeuMF-MovieLens)
 add_subdirectory(GPT2-Inference)
+add_subdirectory(WordSeg)
diff --git a/Examples/WordSeg/CMakeLists.txt b/Examples/WordSeg/CMakeLists.txt
new file mode 100644
index 00000000000..2674e2f0618
--- /dev/null
+++ b/Examples/WordSeg/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_executable(WordSeg
+  main.swift)
+target_link_libraries(WordSeg PRIVATE
+  TextModels
+  Datasets)
+
+
+install(TARGETS WordSeg
+  DESTINATION bin)
diff --git a/Examples/WordSeg/README.md b/Examples/WordSeg/README.md
new file mode 100644
index 00000000000..f928304292d
--- /dev/null
+++ b/Examples/WordSeg/README.md
@@ -0,0 +1,38 @@
+# Word Segmentation
+
+This example demonstrates how to train the [word segmentation model][paper]
+against the dataset provided in the paper.
+
+A segmental neural language model (SNLM) is instantiated from the library of
+standard models. A custom training loop is defined and the training
+losses for each epoch are shown.
+
+## Setup
+
+To begin, you'll need the [latest version of Swift for
+TensorFlow][s4tf] installed. Make sure you've added the correct version of
+`swift` to your path.
+
+To train the model using the full datasets published in the paper, run:
+
+```sh
+cd swift-models
+swift run -c release WordSeg
+```
+
+To train the model using a smaller, unrealistic sample dataset, run:
+
+```sh
+cd swift-models
+swift run -c release WordSeg Examples/WordSeg/smalldata.txt
+```
+
+To run the model with your own dataset, run:
+
+```sh
+cd swift-models
+swift run -c release WordSeg path/to/training_data.txt [path/to/validation_data.txt [path/to/test_data.txt]]
+```
+
+[s4tf]: https://github.com/tensorflow/swift/blob/master/Installation.md
+[paper]: https://www.aclweb.org/anthology/P19-1645.pdf
diff --git a/Examples/WordSeg/main.swift b/Examples/WordSeg/main.swift
new file mode 100644
index 00000000000..2d05d71b133
--- /dev/null
+++ b/Examples/WordSeg/main.swift
@@ -0,0 +1,115 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Datasets
+import ModelSupport
+import TensorFlow
+import TextModels
+
+// Model flags
+let ndim = 512  // Hidden unit size.
+// Training flags
+let dropoutProb = 0.5  // Dropout rate.
+let order = 5  // Power of length penalty.
+let maxEpochs = 1000  // Maximum number of training epochs.
+let lambd: Float = 0.00075  // Weight of length penalty.
+// Lexicon flags.
+let maxLength = 10  // Maximum length of a string.
+let minFreq = 10  // Minimum frequency of a string.
+
+// Load user-provided data files.
+let dataset: WordSegDataset
+switch CommandLine.arguments.count {
+case 1:
+  dataset = try WordSegDataset()
+case 2:
+  dataset = try WordSegDataset(training: CommandLine.arguments[1])
+case 3:
+  dataset = try WordSegDataset(
+    training: CommandLine.arguments[1], validation: CommandLine.arguments[2])
+case 4:
+  dataset = try WordSegDataset(
+    training: CommandLine.arguments[1], validation: CommandLine.arguments[2],
+    testing: CommandLine.arguments[3])
+default:
+  usage()
+}
+
+let lexicon = Lexicon(
+  from: dataset.training,
+  alphabet: dataset.alphabet,
+  maxLength: maxLength,
+  minFreq: minFreq
+)
+
+let modelParameters = SNLM.Parameters(
+  ndim: ndim,
+  dropoutProb: dropoutProb,
+  chrVocab: dataset.alphabet,
+  strVocab: lexicon,
+  order: order
+)
+
+var model = SNLM(parameters: modelParameters)
+
+let optimizer = Adam(for: model)
+
+print("Starting training...")
+
+for epoch in 1...maxEpochs {
+  Context.local.learningPhase = .training
+  var trainingLossSum: Float = 0
+  var trainingBatchCount = 0
+  for sentence in dataset.training {
+    let (loss, gradients) = valueWithGradient(at: model) { model -> Float in
+      let lattice = model.buildLattice(sentence, maxLen: maxLength)
+      let score = lattice[sentence.count].semiringScore
+      let expectedLength = exp(score.logr - score.logp)
+      let loss = -1 * score.logp + lambd * expectedLength
+      return loss
+    }
+
+    trainingLossSum += loss
+    trainingBatchCount += 1
+    optimizer.update(&model, along: gradients)
+
+    if hasNaN(gradients) {
+      print("Warning: grad has NaN")
+    }
+    if hasNaN(model) {
+      print("Warning: model has NaN")
+    }
+  }
+
+  print(
+    """
+    [Epoch \(epoch)] \
+    Loss: \(trainingLossSum / Float(trainingBatchCount))
+    """
+  )
+}
+
+func hasNaN<T: KeyPathIterable>(_ t: T) -> Bool {
+  for kp in t.recursivelyAllKeyPaths(to: Tensor<Float>.self) {
+    if t[keyPath: kp].isNaN.any() { return true }
+  }
+  return false
+}
+
+func usage() -> Never {
+  print(
+    "\(CommandLine.arguments[0]) path/to/training_data.txt [path/to/validation_data.txt [path/to/test_data.txt]]"
+  )
+  exit(1)
+}
diff --git a/Examples/WordSeg/smalldata.txt b/Examples/WordSeg/smalldata.txt
new file mode 100644
index 00000000000..0ba96f93cfe
--- /dev/null
+++ b/Examples/WordSeg/smalldata.txt
@@ -0,0 +1,12 @@
+hello world
+hello world
+hello world
+hello world
+hello world
+hello world
+hello world
+hello world
+hello world
+hello world
+hello world
+goodbye world
diff --git a/Package.swift b/Package.swift
index 23641282dae..5167e49a6e4 100644
--- a/Package.swift
+++ b/Package.swift
@@ -35,6 +35,7 @@ let package = Package(
         .executable(name: "GPT2-WikiText2", targets: ["GPT2-WikiText2"]),
         .executable(name: "NeuMF-MovieLens", targets: ["NeuMF-MovieLens"]),
         .executable(name: "CycleGAN", targets: ["CycleGAN"]),
+        .executable(name: "WordSeg", targets: ["WordSeg"]),
     ],
     dependencies: [
         .package(url: "https://github.com/apple/swift-protobuf.git", from: "1.7.0"),
@@ -135,6 +136,11 @@
         name: "pix2pix",
         dependencies: ["Batcher", .product(name: "ArgumentParser", package: "swift-argument-parser"), "ModelSupport", "Datasets"],
         path: "pix2pix"
-        )
+        ),
+        .target(
+            name: "WordSeg",
+            dependencies: ["ModelSupport", "TextModels", "Datasets"],
+            path: "Examples/WordSeg"
+        )
     ]
 )