|
22 | 22 | Union,
|
23 | 23 | )
|
24 | 24 |
|
25 |
| -try: |
26 |
| - from tqdm.auto import tqdm |
27 |
| -except ImportError: |
28 |
| - tqdm = None |
| 25 | +from tqdm.auto import tqdm |
29 | 26 |
|
30 | 27 | T = TypeVar("T")
|
31 | 28 | IterableType = TypeVar("IterableType")
|
@@ -105,144 +102,41 @@ def close(self) -> None:
|
105 | 102 | pass
|
106 | 103 |
|
107 | 104 |
|
108 |
| -class SimpleProgress(Iterable[IterableType]): |
109 |
| - def __init__( |
110 |
| - self, |
111 |
| - iterable: Optional[Iterable[IterableType]] = None, |
112 |
| - desc: Optional[str] = None, |
113 |
| - total: Optional[int] = None, |
114 |
| - file: Optional[TextIO] = None, |
115 |
| - mininterval: float = 0.5, |
116 |
| - ) -> None: |
117 |
| - """ |
118 |
| - Simple progress output used when tqdm is unavailable. |
119 |
| - Same as tqdm, output to stderr channel. |
120 |
| - If you want to do nested Progressbars with simple progress |
121 |
| - the parent progress bar should be used as a context |
122 |
| - (i.e. with statement) and the nested progress bar should be |
123 |
| - created inside this context. |
124 |
| - """ |
125 |
| - self.cur = 0 |
126 |
| - self.iterable = iterable |
127 |
| - self.total = total |
128 |
| - if total is None and hasattr(iterable, "__len__"): |
129 |
| - self.total = len(cast(Sized, iterable)) |
130 |
| - |
131 |
| - self.desc = desc |
132 |
| - |
133 |
| - file_wrapper = DisableErrorIOWrapper(file if file else sys.stderr) |
134 |
| - self.file: DisableErrorIOWrapper = file_wrapper |
135 |
| - |
136 |
| - self.mininterval = mininterval |
137 |
| - self.last_print_t = 0.0 |
138 |
| - self.closed = False |
139 |
| - self._is_parent = False |
140 |
| - |
141 |
| - def __enter__(self) -> "SimpleProgress[IterableType]": |
142 |
| - self._is_parent = True |
143 |
| - self._refresh() |
144 |
| - return self |
145 |
| - |
146 |
| - def __exit__( |
147 |
| - self, |
148 |
| - exc_type: Union[Type[BaseException], None], |
149 |
| - exc_value: Union[BaseException, None], |
150 |
| - exc_traceback: Union[TracebackType, None], |
151 |
| - ) -> Literal[False]: |
152 |
| - self.close() |
153 |
| - return False |
154 |
| - |
155 |
| - def __iter__(self) -> Iterator[IterableType]: |
156 |
| - if self.closed or not self.iterable: |
157 |
| - return |
158 |
| - self._refresh() |
159 |
| - for it in cast(Iterable[IterableType], self.iterable): |
160 |
| - yield it |
161 |
| - self.update() |
162 |
| - self.close() |
163 |
| - |
164 |
| - def _refresh(self) -> None: |
165 |
| - progress_str = self.desc + ": " if self.desc else "" |
166 |
| - if self.total: |
167 |
| - # e.g., progress: 60% 3/5 |
168 |
| - progress_str += ( |
169 |
| - f"{100 * self.cur // cast(int, self.total)}%" |
170 |
| - f" {self.cur}/{cast(int, self.total)}" |
171 |
| - ) |
172 |
| - else: |
173 |
| - # e.g., progress: ..... |
174 |
| - progress_str += "." * self.cur |
175 |
| - end = "\n" if self._is_parent else "" |
176 |
| - print("\r" + progress_str, end=end, file=self.file) |
177 |
| - |
178 |
| - def update(self, amount: int = 1) -> None: |
179 |
| - if self.closed: |
180 |
| - return |
181 |
| - self.cur += amount |
182 |
| - |
183 |
| - cur_t = time() |
184 |
| - if cur_t - self.last_print_t >= self.mininterval: |
185 |
| - self._refresh() |
186 |
| - self.last_print_t = cur_t |
187 |
| - |
188 |
| - def close(self) -> None: |
189 |
| - if not self.closed and not self._is_parent: |
190 |
| - self._refresh() |
191 |
| - print(file=self.file) # end with new line |
192 |
| - self.closed = True |
193 |
| - |
194 |
| - |
195 | 105 | @typing.overload
|
196 | 106 | def progress(
|
197 | 107 | iterable: None = None,
|
198 | 108 | desc: Optional[str] = None,
|
199 | 109 | total: Optional[int] = None,
|
200 |
| - use_tqdm: bool = True, |
201 | 110 | file: Optional[TextIO] = None,
|
202 | 111 | mininterval: float = 0.5,
|
203 | 112 | **kwargs: object,
|
204 |
| -) -> Union[SimpleProgress[None], tqdm]: ... |
| 113 | +) -> tqdm: ... |
205 | 114 |
|
206 | 115 |
|
207 | 116 | @typing.overload
|
208 | 117 | def progress(
|
209 | 118 | iterable: Iterable[IterableType],
|
210 | 119 | desc: Optional[str] = None,
|
211 | 120 | total: Optional[int] = None,
|
212 |
| - use_tqdm: bool = True, |
213 | 121 | file: Optional[TextIO] = None,
|
214 | 122 | mininterval: float = 0.5,
|
215 | 123 | **kwargs: object,
|
216 |
| -) -> Union[SimpleProgress[IterableType], tqdm]: ... |
| 124 | +) -> tqdm: ... |
217 | 125 |
|
218 | 126 |
|
219 | 127 | def progress(
|
220 | 128 | iterable: Optional[Iterable[IterableType]] = None,
|
221 | 129 | desc: Optional[str] = None,
|
222 | 130 | total: Optional[int] = None,
|
223 |
| - use_tqdm: bool = True, |
224 | 131 | file: Optional[TextIO] = None,
|
225 | 132 | mininterval: float = 0.5,
|
226 | 133 | **kwargs: object,
|
227 |
| -) -> Union[SimpleProgress[IterableType], tqdm]: |
228 |
| - # Try to use tqdm is possible. Fall back to simple progress print |
229 |
| - if tqdm and use_tqdm: |
230 |
| - return tqdm( |
231 |
| - iterable, |
232 |
| - desc=desc, |
233 |
| - total=total, |
234 |
| - file=file, |
235 |
| - mininterval=mininterval, |
236 |
| - **kwargs, |
237 |
| - ) |
238 |
| - else: |
239 |
| - if not tqdm and use_tqdm: |
240 |
| - warnings.warn( |
241 |
| - "Tried to show progress with tqdm " |
242 |
| - "but tqdm is not installed. " |
243 |
| - "Fall back to simply print out the progress.", |
244 |
| - stacklevel=1, |
245 |
| - ) |
246 |
| - return SimpleProgress( |
247 |
| - iterable, desc=desc, total=total, file=file, mininterval=mininterval |
248 |
| - ) |
| 134 | +) -> tqdm: |
| 135 | + return tqdm( |
| 136 | + iterable, |
| 137 | + desc=desc, |
| 138 | + total=total, |
| 139 | + file=file, |
| 140 | + mininterval=mininterval, |
| 141 | + **kwargs, |
| 142 | + ) |
0 commit comments