@@ -15,21 +15,85 @@ package optimize
1515
1616import (
1717 "bytes"
18+ "encoding/json"
1819 "fmt"
20+ "sqlflow.org/sqlflow/pkg/attribute"
1921 "sqlflow.org/sqlflow/pkg/ir"
2022 pb "sqlflow.org/sqlflow/pkg/proto"
2123 "strings"
2224 "text/template"
2325)
2426
27+ func checkIsPositiveInteger (i interface {}, name string ) error {
28+ if v , ok := i .(int ); ! ok || v <= 0 {
29+ return fmt .Errorf ("%s should be positive integer" , name )
30+ }
31+ return nil
32+ }
33+
34+ // TODO(sneaxiy): polish attribute codes
35+ var attributeDictionary = attribute.Dictionary {
36+ "data.enable_slice" : {attribute .Bool , false , "Whether to enable data slicing" , nil },
37+ "data.batch_size" : {attribute .Int , - 1 , "Batch size when training" , nil },
38+ "worker.num" : {attribute .Int , 1 , "Worker number" , func (i interface {}) error {
39+ return checkIsPositiveInteger (i , "worker.num" )
40+ }},
41+ "worker.core" : {attribute .Int , 8 , "Worker core number" , func (i interface {}) error {
42+ return checkIsPositiveInteger (i , "worker.core" )
43+ }},
44+ "worker.memory" : {attribute .Int , 4096 , "Worker memory" , func (i interface {}) error {
45+ return checkIsPositiveInteger (i , "worker.memory" )
46+ }},
47+ "solver.*" : {attribute .Unknown , nil , "Solver options" , nil },
48+ }
49+
50+ // InitializeAttributes initialize attributes in optimize clause IR
51+ func InitializeAttributes (stmt * ir.OptimizeStmt ) error {
52+ attributeDictionary .FillDefaults (stmt .Attributes )
53+ err := attributeDictionary .Validate (stmt .Attributes )
54+ return err
55+ }
56+
2557// GenerateOptFlowOptimizeCode generates optimize codes for execution
2658// The returned value is (runnerProgramCode, submitProgramCode, error)
27- func GenerateOptFlowOptimizeCode (optimStmt * ir.OptimizeStmt , session * pb.Session , dbName , tableName , runnerModuleName string , isPai bool ) (string , string , error ) {
59+ func GenerateOptFlowOptimizeCode (optimStmt * ir.OptimizeStmt , session * pb.Session , dbName , tableName , runnerModuleName string ) (string , string , error ) {
60+ const (
61+ dataAttrPrefix = "data."
62+ solverAttrPrefix = "solver."
63+ workerAttrPrefix = "worker."
64+ )
65+
2866 resultTable := optimStmt .ResultTable
2967 if ! strings .Contains (resultTable , "." ) {
3068 resultTable = fmt .Sprintf ("%s.%s" , dbName , resultTable )
3169 }
3270
71+ attrs := make (map [string ]map [string ]interface {})
72+ for k , v := range optimStmt .Attributes {
73+ prefix := ""
74+ if strings .HasPrefix (k , dataAttrPrefix ) {
75+ prefix = dataAttrPrefix
76+ } else if strings .HasPrefix (k , solverAttrPrefix ) {
77+ prefix = solverAttrPrefix
78+ } else if strings .HasPrefix (k , workerAttrPrefix ) {
79+ prefix = workerAttrPrefix
80+ } else {
81+ return "" , "" , fmt .Errorf ("unrecognized attribute %s" , k )
82+ }
83+
84+ k = k [len (prefix ):]
85+ prefixKey := prefix [0 : len (prefix )- 1 ]
86+ if _ , ok := attrs [prefixKey ]; ! ok {
87+ attrs [prefixKey ] = make (map [string ]interface {})
88+ }
89+ attrs [prefixKey ][k ] = v
90+ }
91+
92+ attrJSON , err := json .Marshal (attrs )
93+ if err != nil {
94+ return "" , "" , err
95+ }
96+
3397 filler := optimizeFiller {
3498 UserID : session .UserId ,
3599 Variables : optimStmt .Variables ,
@@ -39,9 +103,9 @@ func GenerateOptFlowOptimizeCode(optimStmt *ir.OptimizeStmt, session *pb.Session
39103 Direction : optimStmt .Direction ,
40104 Constraints : optimStmt .Constraints ,
41105 Solver : optimStmt .Solver ,
106+ AttributeJSON : string (attrJSON ),
42107 TrainTable : fmt .Sprintf ("%s.%s" , dbName , tableName ),
43108 ResultTable : resultTable ,
44- IsPAI : isPai ,
45109 RunnerModule : runnerModuleName ,
46110 }
47111
0 commit comments