@@ -13,8 +13,8 @@ def filter_lines(text):
1313 return '' .join (lines )
1414
1515
16- def sort_columns (match ):
17- """Sort column lines in a schema dump alphabetically"""
16+ def sort_create_table_columns (match ):
17+ """Sort column lines in a CREATE TABLE block alphabetically"""
1818 lines = [
1919 line .strip ().rstrip (',' )
2020 for line in match .group (1 ).split ('\n ' )
@@ -23,5 +23,32 @@ def sort_columns(match):
2323 return '(\n ' + ',\n ' .join (sorted (lines )) + '\n )'
2424
2525
26+ def sort_copy_columns (text ):
27+ """Sort columns in COPY blocks and reorder the data rows accordingly"""
28+ result = []
29+ lines = text .split ('\n ' )
30+ i = 0
31+ while i < len (lines ):
32+ copy_match = re .match (r'(COPY \S+ \()([^)]+)(\) FROM stdin;)' , lines [i ])
33+ if copy_match :
34+ cols = [c .strip () for c in copy_match .group (2 ).split (',' )]
35+ sorted_indices = sorted (range (len (cols )), key = lambda j : cols [j ])
36+ sorted_cols = [cols [j ] for j in sorted_indices ]
37+ result .append (copy_match .group (1 ) + ', ' .join (sorted_cols ) + copy_match .group (3 ))
38+ i += 1
39+ while i < len (lines ) and lines [i ] != '\\ .' :
40+ values = lines [i ].split ('\t ' )
41+ result .append ('\t ' .join (values [j ] for j in sorted_indices ))
42+ i += 1
43+ if i < len (lines ):
44+ result .append (lines [i ]) # \.
45+ else :
46+ result .append (lines [i ])
47+ i += 1
48+ return '\n ' .join (result )
49+
50+
2651text = filter_lines (sys .stdin .read ())
27- sys .stdout .write (re .sub (r'\(((?:\n [^\n]+)+)\n\)' , sort_columns , text ))
52+ text = re .sub (r'\(((?:\n [^\n]+)+)\n\)' , sort_create_table_columns , text )
53+ text = sort_copy_columns (text )
54+ sys .stdout .write (text )
0 commit comments