1111// See the License for the specific language governing permissions and
1212// limitations under the License.
1313
14- package sql
14+ package executor
1515
1616import (
1717 "bufio"
@@ -40,35 +40,28 @@ import (
4040
4141var rePyDiagnosis = regexp .MustCompile ("sqlflow_submitter.tensorflow.diag.SQLFlowDiagnostic: (.*)" )
4242
43- // GetSubmitter returns a proper Submitter from configurations in environment variables.
44- func GetSubmitter (submitter string ) Submitter {
45- if submitter == "" {
46- submitter = os .Getenv ("SQLFLOW_submitter" )
47- }
48- switch submitter {
49- case "default" :
50- return & defaultSubmitter {}
51- case "pai" :
52- return & paiSubmitter {& defaultSubmitter {}}
53- case "alisa" :
54- return & alisaSubmitter {& defaultSubmitter {}}
55- // TODO(typhoonzero): add submitters like alps, elasticdl
56- default :
57- return & defaultSubmitter {}
58- }
59- }
60-
6143// Figures contains analyzed figures as strings
6244type Figures struct {
6345 Image string
6446 Text string
6547}
6648
67- // Submitter extends ir.Executor
68- type Submitter interface {
69- ir.Executor
70- Setup (* pipe.Writer , * database.DB , string , string , * pb.Session )
71- GetTrainStmtFromModel () bool
49+ // New returns a proper Submitter from configurations in environment variables.
50+ func New (executor string ) ir.Executor {
51+ if executor == "" {
52+ executor = os .Getenv ("SQLFLOW_submitter" )
53+ }
54+ switch executor {
55+ case "default" :
56+ return & pythonExecutor {}
57+ case "pai" :
58+ return & paiExecutor {& pythonExecutor {}}
59+ case "alisa" :
60+ return & alisaExecutor {& pythonExecutor {}}
61+ // TODO(typhoonzero): add executor like alps, elasticdl
62+ default :
63+ return & pythonExecutor {}
64+ }
7265}
7366
7467type logChanWriter struct {
@@ -110,20 +103,20 @@ func (cw *logChanWriter) Close() {
110103 }
111104}
112105
113- type defaultSubmitter struct {
106+ type pythonExecutor struct {
114107 Writer * pipe.Writer
115108 Db * database.DB
116109 ModelDir string
117110 Cwd string
118111 Session * pb.Session
119112}
120113
121- func (s * defaultSubmitter ) Setup (w * pipe.Writer , db * database.DB , modelDir string , cwd string , session * pb.Session ) {
114+ func (s * pythonExecutor ) Setup (w * pipe.Writer , db * database.DB , modelDir string , cwd string , session * pb.Session ) {
122115 // cwd is used to store train scripts and save output models.
123116 s .Writer , s .Db , s .ModelDir , s .Cwd , s .Session = w , db , modelDir , cwd , session
124117}
125118
126- func (s * defaultSubmitter ) SaveModel (cl * ir.TrainStmt ) error {
119+ func (s * pythonExecutor ) SaveModel (cl * ir.TrainStmt ) error {
127120 m := model .New (s .Cwd , cl .OriginalSQL )
128121 modelURI := cl .Into
129122 if s .ModelDir != "" {
@@ -132,7 +125,7 @@ func (s *defaultSubmitter) SaveModel(cl *ir.TrainStmt) error {
132125 return m .Save (modelURI , cl , s .Session )
133126}
134127
135- func (s * defaultSubmitter ) runCommand (program string , logStderr bool ) error {
128+ func (s * pythonExecutor ) runCommand (program string , logStderr bool ) error {
136129 cw := & logChanWriter {wr : s .Writer }
137130 defer cw .Close ()
138131 cmd := sqlflowCmd (s .Cwd , s .Db .DriverName )
@@ -159,11 +152,11 @@ func (s *defaultSubmitter) runCommand(program string, logStderr bool) error {
159152 return nil
160153}
161154
162- func (s * defaultSubmitter ) ExecuteQuery (sql * ir.NormalStmt ) error {
163- return runNormalStmt (s .Writer , string (* sql ), s .Db )
155+ func (s * pythonExecutor ) ExecuteQuery (stmt * ir.NormalStmt ) error {
156+ return runNormalStmt (s .Writer , string (* stmt ), s .Db )
164157}
165158
166- func (s * defaultSubmitter ) ExecuteTrain (cl * ir.TrainStmt ) (e error ) {
159+ func (s * pythonExecutor ) ExecuteTrain (cl * ir.TrainStmt ) (e error ) {
167160 var code string
168161 if isXGBoostModel (cl .Estimator ) {
169162 if code , e = xgboost .Train (cl , s .Session ); e != nil {
@@ -180,9 +173,9 @@ func (s *defaultSubmitter) ExecuteTrain(cl *ir.TrainStmt) (e error) {
180173 return s .SaveModel (cl )
181174}
182175
183- func (s * defaultSubmitter ) ExecutePredict (cl * ir.PredictStmt ) (e error ) {
176+ func (s * pythonExecutor ) ExecutePredict (cl * ir.PredictStmt ) (e error ) {
184177 // NOTE(typhoonzero): model is already loaded under s.Cwd
185- if e = createPredictionTableFromIR (cl , s .Db , s .Session ); e != nil {
178+ if e = createPredictionResultTable (cl , s .Db , s .Session ); e != nil {
186179 return e
187180 }
188181
@@ -199,7 +192,7 @@ func (s *defaultSubmitter) ExecutePredict(cl *ir.PredictStmt) (e error) {
199192 return s .runCommand (code , false )
200193}
201194
202- func (s * defaultSubmitter ) ExecuteExplain (cl * ir.ExplainStmt ) error {
195+ func (s * pythonExecutor ) ExecuteExplain (cl * ir.ExplainStmt ) error {
203196 // NOTE(typhoonzero): model is already loaded under s.Cwd
204197 var code string
205198 var err error
@@ -239,7 +232,7 @@ func (s *defaultSubmitter) ExecuteExplain(cl *ir.ExplainStmt) error {
239232 return nil
240233}
241234
242- func (s * defaultSubmitter ) ExecuteEvaluate (cl * ir.EvaluateStmt ) error {
235+ func (s * pythonExecutor ) ExecuteEvaluate (cl * ir.EvaluateStmt ) error {
243236 // NOTE(typhoonzero): model is already loaded under s.Cwd
244237 var code string
245238 var err error
@@ -280,7 +273,7 @@ func (s *defaultSubmitter) ExecuteEvaluate(cl *ir.EvaluateStmt) error {
280273 return nil
281274}
282275
283- func generateOptFlowOptimizeCodeAndExecute (cl * ir.OptimizeStmt , submitter * defaultSubmitter , session * pb.Session , cwd string , dbName string , tableName string , isPai bool ) error {
276+ func generateOptFlowOptimizeCodeAndExecute (cl * ir.OptimizeStmt , submitter * pythonExecutor , session * pb.Session , cwd string , dbName string , tableName string , isPai bool ) error {
284277 // Generate optimization code
285278 runnerFileName := "custom_optimize_runner"
286279 runnerCode , submitCode , err := optimize .GenerateOptFlowOptimizeCode (cl , session , dbName , tableName ,
@@ -311,7 +304,7 @@ func generateOptFlowOptimizeCodeAndExecute(cl *ir.OptimizeStmt, submitter *defau
311304 return nil
312305}
313306
314- func (s * defaultSubmitter ) ExecuteOptimize (cl * ir.OptimizeStmt ) error {
307+ func (s * pythonExecutor ) ExecuteOptimize (cl * ir.OptimizeStmt ) error {
315308 // TODO(sneaxiy): to be implemented
316309 return fmt .Errorf ("ExecuteOptimize is not supported in default submitter" )
317310}
@@ -359,9 +352,9 @@ func readExplainResult(target string) (string, error) {
359352 return fmt .Sprintf ("<div align='center'><img src='data:image/png;base64,%s' /></div>" , img ), nil
360353}
361354
362- func (s * defaultSubmitter ) GetTrainStmtFromModel () bool { return true }
355+ func (s * pythonExecutor ) GetTrainStmtFromModel () bool { return true }
363356
364- func (s * defaultSubmitter ) ExecuteShowTrain (showTrain * ir.ShowTrainStmt ) error {
357+ func (s * pythonExecutor ) ExecuteShowTrain (showTrain * ir.ShowTrainStmt ) error {
365358 model , err := model .Load (showTrain .ModelName , "" , s .Db )
366359 if err != nil {
367360 s .Writer .Write ("Load model meta " + showTrain .ModelName + " failed." )
0 commit comments