1import csv
2import re
3import textwrap
4
5from . import NOT_SET, strutil, fsutil
6
7
# Placeholder that fix_row() substitutes for any value matching one of the
# "empty" markers.
EMPTY = '-'
# Placeholder that fix_row() substitutes for any value matching one of the
# "unknown" markers.
UNKNOWN = '???'
10
11
def parse_markers(markers, default=None):
    """Normalize a "markers" argument into a collection of marker values.

    * NOT_SET           -> *default*
    * any other falsy   -> None
    * a non-str object  -> returned unchanged (assumed to be a collection)
    * a str made of one repeated character (e.g. '-' or '---')
                        -> a one-item list holding the whole string
    * any other str     -> a list of its individual characters
    """
    if markers is NOT_SET:
        return default
    elif not markers:
        return None
    elif type(markers) is not str:
        return markers
    # A non-empty string consists of a single repeated character exactly
    # when it has only one distinct character.
    return [markers] if len(set(markers)) == 1 else list(markers)
22
23
def fix_row(row, **markers):
    """Return *row* as a tuple with its values normalized.

    Falsy values become None.  A value found among the "empty" markers
    becomes EMPTY; otherwise a value found among the "unknown" markers
    becomes UNKNOWN.  The markers come from the optional "empty" and
    "unknown" keyword arguments (run through parse_markers()); any other
    keyword argument is ignored.

    Raises NotImplementedError if *row* is a str.
    """
    if isinstance(row, str):
        raise NotImplementedError(row)
    empty_markers = parse_markers(markers.pop('empty', ('-',)))
    unknown_markers = parse_markers(markers.pop('unknown', ('???',)))

    def normalize(value):
        value = value if value else None
        # "empty" takes precedence over "unknown".
        if empty_markers and value in empty_markers:
            return EMPTY
        if unknown_markers and value in unknown_markers:
            return UNKNOWN
        return value

    return tuple(normalize(value) for value in row)
39
40
41def _fix_read_default(row):
42    for value in row:
43        yield value.strip()
44
45
46def _fix_write_default(row, empty=''):
47    for value in row:
48        yield empty if value is None else str(value)
49
50
51def _normalize_fix_read(fix):
52    if fix is None:
53        fix = ''
54    if callable(fix):
55        def fix_row(row):
56            values = fix(row)
57            return _fix_read_default(values)
58    elif isinstance(fix, str):
59        def fix_row(row):
60            values = _fix_read_default(row)
61            return (None if v == fix else v
62                    for v in values)
63    else:
64        raise NotImplementedError(fix)
65    return fix_row
66
67
68def _normalize_fix_write(fix, empty=''):
69    if fix is None:
70        fix = empty
71    if callable(fix):
72        def fix_row(row):
73            values = fix(row)
74            return _fix_write_default(values, empty)
75    elif isinstance(fix, str):
76        def fix_row(row):
77            return _fix_write_default(row, fix)
78    else:
79        raise NotImplementedError(fix)
80    return fix_row
81
82
def read_table(infile, header, *,
               sep='\t',
               fix=None,
               _open=open,
               _get_reader=csv.reader,
               ):
    """Yield each row (as a tuple) of the given separator-delimited file.

    *infile* is either a filename (opened here with newline='') or an
    iterable of lines.  *header* is the expected header, either as a
    single string or as a sequence of column names; the first significant
    line of the file must match it.  *sep* is the column separator; a
    falsy value means tab.  *fix* is passed to _normalize_fix_read() to
    decide how each row's values are cleaned up.

    This is a generator, so nothing happens (and nothing is raised)
    until it is iterated.

    Raises ValueError if the file's header does not match *header*.
    """
    # Resolve the separator once so that the expected header and the
    # reader below always agree on it (a falsy sep means tab).
    sep = sep or '\t'

    if isinstance(infile, str):
        with _open(infile, newline='') as infile:
            yield from read_table(
                infile,
                header,
                sep=sep,
                fix=fix,
                _open=_open,
                _get_reader=_get_reader,
            )
            return
    lines = strutil._iter_significant_lines(infile)

    # Validate the header.
    if not isinstance(header, str):
        header = sep.join(header)
    try:
        actualheader = next(lines).strip()
    except StopIteration:
        # No significant lines at all; only okay for an empty header.
        actualheader = ''
    if actualheader != header:
        raise ValueError(f'bad header {actualheader!r}')

    fix_row = _normalize_fix_read(fix)
    for row in _get_reader(lines, delimiter=sep):
        yield tuple(fix_row(row))
116
117
def write_table(outfile, header, rows, *,
                sep='\t',
                fix=None,
                backup=True,
                _open=open,
                _get_writer=csv.writer,
                ):
    """Write the header and each of the rows to a separator-delimited file.

    *outfile* is either a filename (opened here for writing with
    newline='') or an already-open writable file.  *header* is either a
    single string (split on the separator) or a sequence of column names.
    *sep* is the column separator; a falsy value means tab.  *fix* is
    passed to _normalize_fix_write() to decide how each row's values are
    converted.  If *backup* is truthy then fsutil.create_backup() is
    called with *outfile* and *backup* before anything is written.
    """
    if backup:
        fsutil.create_backup(outfile, backup)
    if isinstance(outfile, str):
        with _open(outfile, 'w', newline='') as outfile:
            return write_table(
                outfile,
                header,
                rows,
                sep=sep,
                fix=fix,
                # The backup (if any) was requested above, before the
                # file was opened for writing.  Requesting it again here
                # would run create_backup() a second time, now against
                # the freshly truncated file.
                backup=False,
                _open=_open,
                _get_writer=_get_writer,
            )

    sep = sep or '\t'
    if isinstance(header, str):
        header = header.split(sep)
    fix_row = _normalize_fix_write(fix)
    writer = _get_writer(outfile, delimiter=sep)
    writer.writerow(header)
    for row in rows:
        writer.writerow(
            tuple(fix_row(row))
        )
150
151
def parse_table(entries, sep, header=None, rawsep=None, *,
                default=NOT_SET,
                strict=True,
                ):
    """Yield (row, filename) for each line from strutil.parse_entries().

    *header* and *sep* are first normalized with
    _normalize_table_file_props().  For lines that come with a filename,
    the first line seen for each new file must be the header (when one
    was given) and is skipped.  For lines without a filename, *rawsep* is
    used as the separator whenever *sep* does not occur in the line.

    When *strict* is true every row must have the same number of columns:
    that of the header if given, otherwise that of the first row parsed.
    Short rows are padded with *default* (see _parse_row()).

    This is a generator, so the errors below are only raised once it is
    iterated.

    Raises ValueError if no separator could be determined and
    NotImplementedError if a file does not start with the expected header.
    """
    header, sep = _normalize_table_file_props(header, sep)
    if not sep:
        raise ValueError('missing "sep"')

    expected = len(header.split(sep)) if header and strict else None
    seen_file = None
    for line, filename in strutil.parse_entries(entries, ignoresep=sep):
        rowsep = sep
        if not filename:
            if rawsep and sep not in line:
                rowsep = rawsep
        elif header and seen_file != filename:
            seen_file = filename
            if line.strip() != header:
                # We expected the header.
                raise NotImplementedError((header, line))
            # Skip the header line itself.
            continue

        row = _parse_row(line, rowsep, expected, default)
        if strict and not expected:
            expected = len(row)
        yield row, filename
183
184
def parse_row(line, sep, *, ncols=None, default=NOT_SET):
    """Split *line* on *sep* into a tuple of stripped values.

    See _parse_row() for the meaning of *ncols* and *default*.

    Raises ValueError if *sep* is falsy.
    """
    if sep:
        return _parse_row(line, sep, ncols, default)
    raise ValueError('missing "sep"')
189
190
191def _parse_row(line, sep, ncols, default):
192    row = tuple(v.strip() for v in line.split(sep))
193    if (ncols or 0) > 0:
194        diff = ncols - len(row)
195        if diff:
196            if default is NOT_SET or diff < 0:
197                raise Exception(f'bad row (expected {ncols} columns, got {row!r})')
198            row += (default,) * diff
199    return row
200
201
202def _normalize_table_file_props(header, sep):
203    if not header:
204        return None, sep
205
206    if not isinstance(header, str):
207        if not sep:
208            raise NotImplementedError(header)
209        header = sep.join(header)
210    elif not sep:
211        for sep in ('\t', ',', ' '):
212            if sep in header:
213                break
214        else:
215            sep = None
216    return header, sep
217
218
219##################################
220# stdout tables
221
# Fallback column width, used by _resolve_width() when neither the column
# spec nor the caller-supplied default provides one.
WIDTH = 20
223
224
def resolve_columns(specs):
    """Return the list of resolved column specs for *specs*.

    *specs* may be an iterable of column specs or a single string of
    specs separated by commas and/or whitespace.
    """
    if isinstance(specs, str):
        # str.split() with no argument already drops surrounding whitespace.
        specs = specs.replace(',', ' ').split()
    return _resolve_colspecs(specs)
229
230
def build_table(specs, *, sep=' ', defaultwidth=None):
    """Return (header, divider, row format) strings for the given columns.

    *specs* is resolved with resolve_columns() and the result is handed
    to _build_table() along with *sep* and *defaultwidth*.
    """
    return _build_table(
        resolve_columns(specs),
        sep=sep,
        defaultwidth=defaultwidth,
    )
234
235
# Matches a whole column spec string of the form
#   [<label>]<field>:<align><width1>
# or
#   [<label>]<field>:<width2>:<fmt>
# where the "[<label>]" prefix and each ":..." suffix are optional.
# The groups, in order, are: label, field, align, width1, width2, fmt.
_COLSPEC_RE = re.compile(textwrap.dedent(r'''
    ^
    (?:
        \[
        (
            (?: [^\s\]] [^\]]* )?
            [^\s\]]
        )  # <label>
        ]
    )?
    ( \w+ )  # <field>
    (?:
        (?:
            :
            ( [<^>] )  # <align>
            ( \d+ )  # <width1>
        )
        |
        (?:
            (?:
                :
                ( \d+ )  # <width2>
            )?
            (?:
                :
                ( .*? )  # <fmt>
            )?
        )
    )?
    $
'''), re.VERBOSE)
267
268
269def _parse_fmt(fmt):
270    if fmt.startswith(tuple('<^>')):
271        align = fmt[0]
272        width = fmt[1:]
273        if width.isdigit():
274            return int(width), align
275    return None, None
276
277
def _parse_colspec(raw):
    """Parse a column spec string into (field, label, width, fmt).

    Returns None if *raw* does not match _COLSPEC_RE.  The label defaults
    to the field name.  For the "<field>:<align><width>" form the width
    is folded into the returned fmt and the returned width is None.  For
    the "<field>:<width>:<fmt>" form the width is returned as an int,
    unless the fmt already encodes the very same width (then it is None).
    """
    matched = _COLSPEC_RE.match(raw)
    if matched is None:
        return None
    label, field, align, width1, width2, fmt = matched.groups()

    width = None
    if width1:
        fmt = f'{align}{width1}'
    elif width2:
        width = int(width2)
        if fmt and _parse_fmt(fmt)[0] == width:
            # The explicit width is redundant with the one in the format.
            width = None
    return field, label or field, width, fmt
297
298
299def _normalize_colspec(spec):
300    if len(spec) == 1:
301        raw, = spec
302        return _resolve_column(raw)
303
304    if len(spec) == 4:
305        label, field, width, fmt = spec
306        if width:
307            fmt = f'{width}:{fmt}' if fmt else width
308    elif len(raw) == 3:
309        label, field, fmt = spec
310        if not field:
311            label, field = None, label
312        elif not isinstance(field, str) or not field.isidentifier():
313            fmt = f'{field}:{fmt}' if fmt else field
314            label, field = None, label
315    elif len(raw) == 2:
316        label = None
317        field, fmt = raw
318        if not field:
319            field, fmt = fmt, None
320        elif not field.isidentifier() or fmt.isidentifier():
321            label, field = field, fmt
322    else:
323        raise NotImplementedError
324
325    fmt = f':{fmt}' if fmt else ''
326    if label:
327        return _parse_colspec(f'[{label}]{field}{fmt}')
328    else:
329        return _parse_colspec(f'{field}{fmt}')
330
331
def _resolve_colspec(raw):
    """Return the parsed column spec for *raw*.

    A str is parsed with _parse_colspec(); anything else is treated as a
    tuple-style spec and goes through _normalize_colspec().

    Raises ValueError if the spec could not be parsed.
    """
    parse = _parse_colspec if isinstance(raw, str) else _normalize_colspec
    spec = parse(raw)
    if spec is None:
        raise ValueError(f'unsupported column spec {raw!r}')
    return spec
340
341
342def _resolve_colspecs(columns):
343    parsed = []
344    for raw in columns:
345        column = _resolve_colspec(raw)
346        parsed.append(column)
347    return parsed
348
349
350def _resolve_width(spec, defaultwidth):
351    _, label, width, fmt = spec
352    if width:
353        if not isinstance(width, int):
354            raise NotImplementedError
355        return width
356    elif width and fmt:
357        width, _ = _parse_fmt(fmt)
358        if width:
359            return width
360
361    if not defaultwidth:
362        return WIDTH
363    elif not hasattr(defaultwidth, 'get'):
364        return defaultwidth or WIDTH
365
366    defaultwidths = defaultwidth
367    defaultwidth = defaultwidths.get(None) or WIDTH
368    return defaultwidths.get(label) or defaultwidth
369
370
371def _build_table(columns, *, sep=' ', defaultwidth=None):
372    header = []
373    div = []
374    rowfmt = []
375    for spec in columns:
376        label, field, _, colfmt = spec
377        width = _resolve_width(spec, defaultwidth)
378        if colfmt:
379            colfmt = f':{colfmt}'
380        else:
381            colfmt = f':{width}'
382
383        header.append(f' {{:^{width}}} '.format(label))
384        div.append('-' * (width + 2))
385        rowfmt.append(f' {{{field}{colfmt}}} ')
386    return (
387        sep.join(header),
388        sep.join(div),
389        sep.join(rowfmt),
390    )
391