11package postgresql
22
33import (
4+ "errors"
45 "fmt"
56 "io"
67 "io/ioutil"
@@ -25,6 +26,9 @@ func stringSlice(list nodes.List) []string {
2526func parseTypeName (node nodes.Node ) (* ast.TypeName , error ) {
2627 switch n := node .(type ) {
2728
29+ case nodes.TypeName :
30+ return parseTypeName (n .Names )
31+
2832 case nodes.List :
2933 parts := stringSlice (n )
3034 switch len (parts ) {
@@ -42,7 +46,7 @@ func parseTypeName(node nodes.Node) (*ast.TypeName, error) {
4246 }
4347
4448 default :
45- return nil , fmt .Errorf ("unexpected node type: %T" , n )
49+ return nil , fmt .Errorf ("parseTypeName: unexpected node type: %T" , n )
4650 }
4751}
4852
@@ -85,7 +89,32 @@ func parseTableName(node nodes.Node) (*ast.TableName, error) {
8589 return & name , nil
8690
8791 default :
88- return nil , fmt .Errorf ("unexpected node type: %T" , n )
92+ return nil , fmt .Errorf ("parseTableName: unexpected node type: %T" , n )
93+ }
94+ }
95+
96+ func parseColName (node nodes.Node ) (* ast.ColumnRef , * ast.TableName , error ) {
97+ switch n := node .(type ) {
98+ case nodes.List :
99+ parts := stringSlice (n )
100+ var tbl * ast.TableName
101+ var ref * ast.ColumnRef
102+ switch len (parts ) {
103+ case 2 :
104+ tbl = & ast.TableName {Name : parts [0 ]}
105+ ref = & ast.ColumnRef {Name : parts [1 ]}
106+ case 3 :
107+ tbl = & ast.TableName {Schema : parts [0 ], Name : parts [1 ]}
108+ ref = & ast.ColumnRef {Name : parts [2 ]}
109+ case 4 :
110+ tbl = & ast.TableName {Catalog : parts [0 ], Schema : parts [1 ], Name : parts [2 ]}
111+ ref = & ast.ColumnRef {Name : parts [3 ]}
112+ default :
113+ return nil , nil , fmt .Errorf ("column specifier %q is not the proper format, expected '[catalog.][schema.]colname.tablename'" , strings .Join (parts , "." ))
114+ }
115+ return ref , tbl , nil
116+ default :
117+ return nil , nil , fmt .Errorf ("parseColName: unexpected node type: %T" , n )
89118 }
90119}
91120
@@ -100,6 +129,8 @@ func NewParser() *Parser {
100129type Parser struct {
101130}
102131
132+ var errSkip = errors .New ("skip stmt" )
133+
103134func (p * Parser ) Parse (r io.Reader ) ([]ast.Statement , error ) {
104135 contents , err := ioutil .ReadAll (r )
105136 if err != nil {
@@ -117,14 +148,18 @@ func (p *Parser) Parse(r io.Reader) ([]ast.Statement, error) {
117148 return nil , fmt .Errorf ("expected RawStmt; got %T" , stmt )
118149 }
119150 n , err := translate (raw .Stmt )
151+ if err == errSkip {
152+ continue
153+ }
120154 if err != nil {
121155 return nil , err
122156 }
123- if n != nil {
124- stmts = append (stmts , ast.Statement {
125- Raw : & ast.RawStmt {Stmt : n },
126- })
157+ if n == nil {
158+ return nil , fmt .Errorf ("unexpected nil node" )
127159 }
160+ stmts = append (stmts , ast.Statement {
161+ Raw : & ast.RawStmt {Stmt : n },
162+ })
128163 }
129164 return stmts , nil
130165}
@@ -183,6 +218,54 @@ func translate(node nodes.Node) (ast.Node, error) {
183218 }
184219 return at , nil
185220
221+ case nodes.CommentStmt :
222+ switch n .Objtype {
223+
224+ case nodes .OBJECT_COLUMN :
225+ col , tbl , err := parseColName (n .Object )
226+ if err != nil {
227+ return nil , fmt .Errorf ("COMMENT ON COLUMN: %w" , err )
228+ }
229+ return & ast.CommentOnColumnStmt {
230+ Col : col ,
231+ Table : tbl ,
232+ Comment : n .Comment ,
233+ }, nil
234+
235+ case nodes .OBJECT_SCHEMA :
236+ o , ok := n .Object .(nodes.String )
237+ if ! ok {
238+ return nil , fmt .Errorf ("COMMENT ON SCHEMA: unexpected node type: %T" , n .Object )
239+ }
240+ return & ast.CommentOnSchemaStmt {
241+ Schema : & ast.String {Str : o .Str },
242+ Comment : n .Comment ,
243+ }, nil
244+
245+ case nodes .OBJECT_TABLE :
246+ name , err := parseTableName (n .Object )
247+ if err != nil {
248+ return nil , fmt .Errorf ("COMMENT ON TABLE: %w" , err )
249+ }
250+ return & ast.CommentOnTableStmt {
251+ Table : name ,
252+ Comment : n .Comment ,
253+ }, nil
254+
255+ case nodes .OBJECT_TYPE :
256+ name , err := parseTypeName (n .Object )
257+ if err != nil {
258+ return nil , err
259+ }
260+ return & ast.CommentOnTypeStmt {
261+ Type : name ,
262+ Comment : n .Comment ,
263+ }, nil
264+
265+ }
266+
267+ return nil , errSkip
268+
186269 case nodes.CreateStmt :
187270 name , err := parseTableName (* n .Relation )
188271 if err != nil {
@@ -259,9 +342,9 @@ func translate(node nodes.Node) (ast.Node, error) {
259342 return drop , nil
260343
261344 }
262- return nil , nil
345+ return nil , errSkip
263346
264347 default :
265- return nil , nil
348+ return nil , errSkip
266349 }
267350}
0 commit comments