]> git.openstreetmap.org Git - nominatim.git/blobdiff - nominatim/tools/special_phrases/sp_csv_loader.py
Merge remote-tracking branch 'upstream/master'
[nominatim.git] / nominatim / tools / special_phrases / sp_csv_loader.py
index b7b24a7dff2cf3c70d148e5d97bfcc6e2d1a9b86..400f9fa91aa3efec500a8e40b3e7f1df08e609bf 100644 (file)
@@ -1,51 +1,45 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
 """
     Module containing the SPCsvLoader class.
 
     The class allows to load phrases from a csv file.
 """
+from typing import Iterable
 import csv
 import os
-from collections.abc import Iterator
 from nominatim.tools.special_phrases.special_phrase import SpecialPhrase
 from nominatim.errors import UsageError
 
-class SPCsvLoader(Iterator):
+class SPCsvLoader:
     """
         Handles loading of special phrases from external csv file.
     """
-    def __init__(self, csv_path):
-        super().__init__()
+    def __init__(self, csv_path: str) -> None:
         self.csv_path = csv_path
-        self.has_been_read = False
 
-    def __next__(self):
-        if self.has_been_read:
-            raise StopIteration()
 
-        self.has_been_read = True
-        self.check_csv_validity()
-        return self.parse_csv()
-
-    def parse_csv(self):
-        """
-            Open and parse the given csv file.
+    def generate_phrases(self) -> Iterable[SpecialPhrase]:
+        """ Open and parse the given csv file.
             Create the corresponding SpecialPhrases.
         """
-        phrases = set()
+        self._check_csv_validity()
 
-        with open(self.csv_path) as file:
-            reader = csv.DictReader(file, delimiter=',')
+        with open(self.csv_path, encoding='utf-8') as fd:
+            reader = csv.DictReader(fd, delimiter=',')
             for row in reader:
-                phrases.add(
-                    SpecialPhrase(row['phrase'], row['class'], row['type'], row['operator'])
-                )
-        return phrases
+                yield SpecialPhrase(row['phrase'], row['class'], row['type'], row['operator'])
+
 
-    def check_csv_validity(self):
+    def _check_csv_validity(self) -> None:
         """
             Check that the csv file has the right extension.
         """
         _, extension = os.path.splitext(self.csv_path)
 
         if extension != '.csv':
-            raise UsageError('The file {} is not a csv file.'.format(self.csv_path))
+            raise UsageError(f'The file {self.csv_path} is not a csv file.')