From 6d275ddf9fdd67a32b79b93d70fedffe4b156d5c Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 10 Nov 2023 14:45:16 +0800 Subject: [PATCH] fixed broken softlinks (#1381) * removed broken softlinks * fixed dependencies * fixed file permission --- icefall/shared/convert-k2-to-openfst.py | 103 +++- icefall/shared/ngram_entropy_pruning.py | 631 +++++++++++++++++++++++- icefall/shared/parse_options.sh | 98 +++- 3 files changed, 829 insertions(+), 3 deletions(-) mode change 120000 => 100755 icefall/shared/convert-k2-to-openfst.py mode change 120000 => 100755 icefall/shared/ngram_entropy_pruning.py mode change 120000 => 100755 icefall/shared/parse_options.sh diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py deleted file mode 120000 index 24efe5eaeb..0000000000 --- a/icefall/shared/convert-k2-to-openfst.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/shared/convert-k2-to-openfst.py \ No newline at end of file diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py new file mode 100755 index 0000000000..29a2cd7f7e --- /dev/null +++ b/icefall/shared/convert-k2-to-openfst.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes as input an FST in k2 format and convert it +to an FST in OpenFST format. + +The generated FST is saved into a binary file and its type is +StdVectorFst. + +Usage examples: +(1) Convert an acceptor + + ./convert-k2-to-openfst.py in.pt binary.fst + +(2) Convert a transducer + + ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import kaldifst.utils +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--olabels", + type=str, + default=None, + help="""If not empty, the input FST is assumed to be a transducer + and we use its attribute specified by "olabels" as the output labels. + """, + ) + parser.add_argument( + "input_filename", + type=str, + help="Path to the input FST in k2 format", + ) + + parser.add_argument( + "output_filename", + type=str, + help="Path to the output FST in OpenFst format", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + logging.info(f"{vars(args)}") + + input_filename = args.input_filename + output_filename = args.output_filename + olabels = args.olabels + + if Path(output_filename).is_file(): + logging.info(f"{output_filename} already exists - skipping") + return + + assert Path(input_filename).is_file(), f"{input_filename} does not exist" + logging.info(f"Loading {input_filename}") + k2_fst = k2.Fsa.from_dict(torch.load(input_filename)) + if olabels: + assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}" + + p = Path(output_filename).parent + if not p.is_dir(): + logging.info(f"Creating {p}") + p.mkdir(parents=True) + + logging.info("Converting (May take some time if the input FST is large)") + fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels) + logging.info(f"Saving to {output_filename}") + fst.write(output_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py deleted file mode 120000 index 0e14ac4158..0000000000 --- a/icefall/shared/ngram_entropy_pruning.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/shared/ngram_entropy_pruning.py \ No newline at end of file diff --git a/icefall/shared/ngram_entropy_pruning.py b/icefall/shared/ngram_entropy_pruning.py new file mode 100755 index 0000000000..b1ebee9ea0 --- /dev/null +++ b/icefall/shared/ngram_entropy_pruning.py @@ -0,0 +1,630 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2021 Johns Hopkins University (Author: Ruizhe Huang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +./ngram_entropy_pruning.py \ + -threshold 1e-8 \ + -lm download/lm/4gram.arpa \ + -write-lm download/lm/4gram_pruned_1e8.arpa + +This file is from Kaldi `egs/wsj/s5/utils/lang/ngram_entropy_pruning.py`. +This is an implementation of ``Entropy-based Pruning of Backoff Language Models'' +in the same way as SRILM. +""" + + +import argparse +import gzip +import logging +import math +import re +from collections import OrderedDict, defaultdict +from enum import Enum, unique +from io import StringIO + +parser = argparse.ArgumentParser( + description=""" + Prune an n-gram language model based on the relative entropy + between the original and the pruned model, based on Andreas Stolcke's paper. + An n-gram entry is removed, if the removal causes (training set) perplexity + of the model to increase by less than threshold relative. + + The command takes an arpa file and a pruning threshold as input, + and outputs a pruned arpa file. + """ +) +parser.add_argument("-threshold", type=float, default=1e-6, help="Order of n-gram") +parser.add_argument("-lm", type=str, default=None, help="Path to the input arpa file") +parser.add_argument( + "-write-lm", type=str, default=None, help="Path to output arpa file after pruning" +) +parser.add_argument( + "-minorder", + type=int, + default=1, + help="The minorder parameter limits pruning to ngrams of that length and above.", +) +parser.add_argument( + "-encoding", type=str, default="utf-8", help="Encoding of the arpa file" +) +parser.add_argument( + "-verbose", + type=int, + default=2, + choices=[0, 1, 2, 3, 4, 5], + help="Verbose level, where 0 is most noisy; 5 is most silent", +) +args = parser.parse_args() + +default_encoding = args.encoding +logging.basicConfig( + format="%(asctime)s — %(levelname)s — %(funcName)s:%(lineno)d — %(message)s", + level=args.verbose * 10, +) + + +class Context(dict): + """ + This class stores data for a context h. + It behaves like a python dict object, except that it has several + additional attributes. + """ + + def __init__(self): + super().__init__() + self.log_bo = None + + +class Arpa: + """ + This is a class that implement the data structure of an APRA LM. + It (as well as some other classes) is modified based on the library + by Stefan Fischer: + https://github.com/sfischer13/python-arpa + """ + + UNK = "" + SOS = "" + EOS = "" + FLOAT_NDIGITS = 7 + base = 10 + + @staticmethod + def _check_input(my_input): + if not my_input: + raise ValueError + elif isinstance(my_input, tuple): + return my_input + elif isinstance(my_input, list): + return tuple(my_input) + elif isinstance(my_input, str): + return tuple(my_input.strip().split(" ")) + else: + raise ValueError + + @staticmethod + def _check_word(input_word): + if not isinstance(input_word, str): + raise ValueError + if " " in input_word: + raise ValueError + + def _replace_unks(self, words): + return tuple((w if w in self else self._unk) for w in words) + + def __init__(self, path=None, encoding=None, unk=None): + self._counts = OrderedDict() + self._ngrams = ( + OrderedDict() + ) # Use self._ngrams[len(h)][h][w] for saving the entry of (h,w) + self._vocabulary = set() + if unk is None: + self._unk = self.UNK + + if path is not None: + self.loadf(path, encoding) + + def __contains__(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + return h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h] + + def contains_word(self, word): + self._check_word(word) + return word in self._vocabulary + + def add_count(self, order, count): + self._counts[order] = count + self._ngrams[order - 1] = defaultdict(Context) + + def update_counts(self): + for order in range(1, self.order() + 1): + count = sum([len(wlist) for _, wlist in self._ngrams[order - 1].items()]) + if count > 0: + self._counts[order] = count + + def add_entry(self, ngram, p, bo=None, order=None): + # Note: ngram is a tuple of strings, e.g. ("w1", "w2", "w3") + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + + # Note that p and bo here are in fact in the log domain (self.base = 10) + h_context = self._ngrams[len(h)][h] + h_context[w] = p + if bo is not None: + self._ngrams[len(ngram)][ngram].log_bo = bo + + for word in ngram: + self._vocabulary.add(word) + + def counts(self): + return sorted(self._counts.items()) + + def order(self): + return max(self._counts.keys(), default=None) + + def vocabulary(self, sort=True): + if sort: + return sorted(self._vocabulary) + else: + return self._vocabulary + + def _entries(self, order): + return ( + self._entry(h, w) + for h, wlist in self._ngrams[order - 1].items() + for w in wlist + ) + + def _entry(self, h, w): + # return the entry for the ngram (h, w) + ngram = h + (w,) + log_p = self._ngrams[len(h)][h][w] + log_bo = self._log_bo(ngram) + if log_bo is not None: + return ( + round(log_p, self.FLOAT_NDIGITS), + ngram, + round(log_bo, self.FLOAT_NDIGITS), + ) + else: + return round(log_p, self.FLOAT_NDIGITS), ngram + + def _log_bo(self, ngram): + if len(ngram) in self._ngrams and ngram in self._ngrams[len(ngram)]: + return self._ngrams[len(ngram)][ngram].log_bo + else: + return None + + def _log_p(self, ngram): + h = ngram[:-1] # h is a tuple + w = ngram[-1] # w is a string/word + if h in self._ngrams[len(h)] and w in self._ngrams[len(h)][h]: + return self._ngrams[len(h)][h][w] + else: + return None + + def log_p_raw(self, ngram): + log_p = self._log_p(ngram) + if log_p is not None: + return log_p + else: + if len(ngram) == 1: + raise KeyError + else: + log_bo = self._log_bo(ngram[:-1]) + if log_bo is None: + log_bo = 0 + return log_bo + self.log_p_raw(ngram[1:]) + + def log_joint_prob(self, sequence): + # Compute the joint prob of the sequence based on the chain rule + # Note that sequence should be a tuple of strings + # + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/LM.cc#L527 + + log_joint_p = 0 + seq = sequence + while len(seq) > 0: + log_joint_p += self.log_p_raw(seq) + seq = seq[:-1] + + # If we're computing the marginal probability of the unigram + # context we have to look up instead since the former + # has prob = 0. + if len(seq) == 1 and seq[0] == self.SOS: + seq = (self.EOS,) + + return log_joint_p + + def set_new_context(self, h): + old_context = self._ngrams[len(h)][h] + self._ngrams[len(h)][h] = Context() + return old_context + + def log_p(self, ngram): + words = self._check_input(ngram) + if self._unk: + words = self._replace_unks(words) + return self.log_p_raw(words) + + def log_s(self, sentence, sos=SOS, eos=EOS): + words = self._check_input(sentence) + if self._unk: + words = self._replace_unks(words) + if sos: + words = (sos,) + words + if eos: + words = words + (eos,) + result = sum(self.log_p_raw(words[:i]) for i in range(1, len(words) + 1)) + if sos: + result = result - self.log_p_raw(words[:1]) + return result + + def p(self, ngram): + return self.base ** self.log_p(ngram) + + def s(self, sentence): + return self.base ** self.log_s(sentence) + + def write(self, fp): + fp.write("\n\\data\\\n") + for order, count in self.counts(): + fp.write("ngram {}={}\n".format(order, count)) + fp.write("\n") + for order, _ in self.counts(): + fp.write("\\{}-grams:\n".format(order)) + for e in self._entries(order): + prob = e[0] + ngram = " ".join(e[1]) + if len(e) == 2: + fp.write("{}\t{}\n".format(prob, ngram)) + elif len(e) == 3: + backoff = e[2] + fp.write("{}\t{}\t{}\n".format(prob, ngram, backoff)) + else: + raise ValueError + fp.write("\n") + fp.write("\\end\\\n") + + +class ArpaParser: + """ + This is a class that implement a parser of an arpa file + """ + + @unique + class State(Enum): + DATA = 1 + COUNT = 2 + HEADER = 3 + ENTRY = 4 + + re_count = re.compile(r"^ngram (\d+)=(\d+)$") + re_header = re.compile(r"^\\(\d+)-grams:$") + re_entry = re.compile( + "^(-?\\d+(\\.\\d+)?([eE]-?\\d+)?)" + "\t" + "(\\S+( \\S+)*)" + "(\t((-?\\d+(\\.\\d+)?)([eE]-?\\d+)?))?$" + ) + + def _parse(self, fp): + self._result = [] + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + for line in fp: + line = line.strip() + if self._state == self.State.DATA: + self._data(line) + elif self._state == self.State.COUNT: + self._count(line) + elif self._state == self.State.HEADER: + self._header(line) + elif self._state == self.State.ENTRY: + self._entry(line) + if self._state != self.State.DATA: + raise Exception(line) + return self._result + + def _data(self, line): + if line == "\\data\\": + self._state = self.State.COUNT + self._tmp_model = Arpa() + else: + pass # skip comment line + + def _count(self, line): + match = self.re_count.match(line) + if match: + order = match.group(1) + count = match.group(2) + self._tmp_model.add_count(int(order), int(count)) + elif not line: + self._state = self.State.HEADER # there are no counts + else: + raise Exception(line) + + def _header(self, line): + match = self.re_header.match(line) + if match: + self._state = self.State.ENTRY + self._tmp_order = int(match.group(1)) + elif line == "\\end\\": + self._result.append(self._tmp_model) + self._state = self.State.DATA + self._tmp_model = None + self._tmp_order = None + elif not line: + pass # skip empty line + else: + raise Exception(line) + + def _entry(self, line): + match = self.re_entry.match(line) + if match: + p = self._float_or_int(match.group(1)) + ngram = tuple(match.group(4).split(" ")) + bo_match = match.group(7) + bo = self._float_or_int(bo_match) if bo_match else None + self._tmp_model.add_entry(ngram, p, bo, self._tmp_order) + elif not line: + self._state = self.State.HEADER # last entry + else: + raise Exception(line) + + @staticmethod + def _float_or_int(s): + f = float(s) + i = int(f) + if str(i) == s: # don't drop trailing ".0" + return i + else: + return f + + def load(self, fp): + """Deserialize fp (a file-like object) to a Python object.""" + return self._parse(fp) + + def loadf(self, path, encoding=None): + """Deserialize path (.arpa, .gz) to a Python object.""" + path = str(path) + if path.endswith(".gz"): + with gzip.open(path, mode="rt", encoding=encoding) as f: + return self.load(f) + else: + with open(path, mode="rt", encoding=encoding) as f: + return self.load(f) + + def loads(self, s): + """Deserialize s (a str) to a Python object.""" + with StringIO(s) as f: + return self.load(f) + + def dump(self, obj, fp): + """Serialize obj to fp (a file-like object) in ARPA format.""" + obj.write(fp) + + def dumpf(self, obj, path, encoding=None): + """Serialize obj to path in ARPA format (.arpa, .gz).""" + path = str(path) + if path.endswith(".gz"): + with gzip.open(path, mode="wt", encoding=encoding) as f: + return self.dump(obj, f) + else: + with open(path, mode="wt", encoding=encoding) as f: + self.dump(obj, f) + + def dumps(self, obj): + """Serialize obj to an ARPA formatted str.""" + with StringIO() as f: + self.dump(obj, f) + return f.getvalue() + + +def add_log_p(prev_log_sum, log_p, base): + return math.log(base**log_p + base**prev_log_sum, base) + + +def compute_numerator_denominator(lm, h): + log_sum_seen_h = -math.inf + log_sum_seen_h_lower = -math.inf + base = lm.base + for w, log_p in lm._ngrams[len(h)][h].items(): + log_sum_seen_h = add_log_p(log_sum_seen_h, log_p, base) + + ngram = h + (w,) + log_p_lower = lm.log_p_raw(ngram[1:]) + log_sum_seen_h_lower = add_log_p(log_sum_seen_h_lower, log_p_lower, base) + + numerator = 1.0 - base**log_sum_seen_h + denominator = 1.0 - base**log_sum_seen_h_lower + return numerator, denominator + + +def prune(lm, threshold, minorder): + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/lm/src/NgramLM.cc#L2330 + + for i in range( + lm.order(), max(minorder - 1, 1), -1 + ): # i is the order of the ngram (h, w) + logging.info("processing %d-grams ..." % i) + count_pruned_ngrams = 0 + + h_dict = lm._ngrams[i - 1] + for h in list(h_dict.keys()): + # old backoff weight, BOW(h) + log_bow = lm._log_bo(h) + if log_bow is None: + log_bow = 0 + + # Compute numerator and denominator of the backoff weight, + # so that we can quickly compute the BOW adjustment due to + # leaving out one prob. + numerator, denominator = compute_numerator_denominator(lm, h) + + # assert abs(math.log(numerator, lm.base) - math.log(denominator, lm.base) - h_dict[h].log_bo) < 1e-5 + + # Compute the marginal probability of the context, P(h) + h_log_p = lm.log_joint_prob(h) + + all_pruned = True + pruned_w_set = set() + + for w, log_p in h_dict[h].items(): + ngram = h + (w,) + + # lower-order estimate for ngramProb, P(w|h') + backoff_prob = lm.log_p_raw(ngram[1:]) + + # Compute BOW after removing ngram, BOW'(h) + new_log_bow = math.log( + numerator + lm.base**log_p, lm.base + ) - math.log(denominator + lm.base**backoff_prob, lm.base) + + # Compute change in entropy due to removal of ngram + delta_prob = backoff_prob + new_log_bow - log_p + delta_entropy = -(lm.base**h_log_p) * ( + (lm.base**log_p) * delta_prob + + numerator * (new_log_bow - log_bow) + ) + + # compute relative change in model (training set) perplexity + perp_change = lm.base**delta_entropy - 1.0 + + pruned = threshold > 0 and perp_change < threshold + + # Make sure we don't prune ngrams whose backoff nodes are needed + if ( + pruned + and len(ngram) in lm._ngrams + and len(lm._ngrams[len(ngram)][ngram]) > 0 + ): + pruned = False + + logging.debug( + "CONTEXT " + + str(h) + + " WORD " + + w + + " CONTEXTPROB %f " % h_log_p + + " OLDPROB %f " % log_p + + " NEWPROB %f " % (backoff_prob + new_log_bow) + + " DELTA-H %f " % delta_entropy + + " DELTA-LOGP %f " % delta_prob + + " PPL-CHANGE %f " % perp_change + + " PRUNED " + + str(pruned) + ) + + if pruned: + pruned_w_set.add(w) + count_pruned_ngrams += 1 + else: + all_pruned = False + + # If we removed all ngrams for this context we can + # remove the context itself, but only if the present + # context is not a prefix to a longer one. + if all_pruned and len(pruned_w_set) == len(h_dict[h]): + del h_dict[ + h + ] # this context h is no longer needed, as its ngram prob is stored at its own context h' + elif len(pruned_w_set) > 0: + # The pruning for this context h is actually done here + old_context = lm.set_new_context(h) + + for w, p_w in old_context.items(): + if w not in pruned_w_set: + lm.add_entry( + h + (w,), p_w + ) # the entry hw is stored at the context h + + # We need to recompute the back-off weight, but + # this can only be done after completing the pruning + # of the lower-order ngrams. + # Reference: + # https://github.com/BitSpeech/SRILM/blob/d571a4424fb0cf08b29fbfccfddd092ea969eae3/flm/src/FNgramLM.cc#L2124 + + logging.info("pruned %d %d-grams" % (count_pruned_ngrams, i)) + + # recompute backoff weights + for i in range( + max(minorder - 1, 1) + 1, lm.order() + 1 + ): # be careful of this order: from low- to high-order + for h in lm._ngrams[i - 1]: + numerator, denominator = compute_numerator_denominator(lm, h) + new_log_bow = math.log(numerator, lm.base) - math.log(denominator, lm.base) + lm._ngrams[len(h)][h].log_bo = new_log_bow + + # update counts + lm.update_counts() + + return + + +def check_h_is_valid(lm, h): + sum_under_h = sum( + [lm.base ** lm.log_p_raw(h + (w,)) for w in lm.vocabulary(sort=False)] + ) + if abs(sum_under_h - 1.0) > 1e-6: + logging.info("warning: %s %f" % (str(h), sum_under_h)) + return False + else: + return True + + +def validate_lm(lm): + # sanity check if the conditional probability sums to one under each context h + for i in range(lm.order(), 0, -1): # i is the order of the ngram (h, w) + logging.info("validating %d-grams ..." % i) + h_dict = lm._ngrams[i - 1] + for h in h_dict.keys(): + check_h_is_valid(lm, h) + + +def compare_two_apras(path1, path2): + pass + + +if __name__ == "__main__": + # load an arpa file + logging.info("Loading the arpa file from %s" % args.lm) + parser = ArpaParser() + models = parser.loadf(args.lm, encoding=default_encoding) + lm = models[0] # ARPA files may contain several models. + logging.info("Stats before pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + + # prune it, the language model will be modified in-place + logging.info("Start pruning the model with threshold=%.3E..." % args.threshold) + prune(lm, args.threshold, args.minorder) + + # validate_lm(lm) + + # write the arpa language model to a file + logging.info("Stats after pruning:") + for i, cnt in lm.counts(): + logging.info("ngram %d=%d" % (i, cnt)) + logging.info("Saving the pruned arpa file to %s" % args.write_lm) + parser.dumpf(lm, args.write_lm, encoding=default_encoding) + logging.info("Done.") diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh deleted file mode 120000 index e4665e7de2..0000000000 --- a/icefall/shared/parse_options.sh +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/shared/parse_options.sh \ No newline at end of file diff --git a/icefall/shared/parse_options.sh b/icefall/shared/parse_options.sh new file mode 100755 index 0000000000..71fb9e5ea1 --- /dev/null +++ b/icefall/shared/parse_options.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0.