Skip to content

Commit e81457e

Browse files
authored
cleanup sql package, move submitters to executor pkg (#2481)
* refactor executor * update * update * update * delete unused cmd.go * fix ut * polish function and add some comments * update * fix ci * revert attribute check to ir_generator
1 parent befcd74 commit e81457e

23 files changed

Lines changed: 347 additions & 303 deletions

cmd/sqlflow/rpc_stream_renderer.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ import (
2020
"regexp"
2121
"time"
2222

23+
"sqlflow.org/sqlflow/pkg/executor"
2324
"sqlflow.org/sqlflow/pkg/proto"
24-
"sqlflow.org/sqlflow/pkg/sql"
2525
"sqlflow.org/sqlflow/pkg/step"
2626
"sqlflow.org/sqlflow/pkg/tablewriter"
2727
)
@@ -55,7 +55,7 @@ func render(ctx *renderContext, obj interface{}) error {
5555
case *proto.Response_Message:
5656
re := regexp.MustCompile(`<div.*?>.*</div>`)
5757
if re.MatchString(r.Message.Message) {
58-
renderObj = sql.Figures{
58+
renderObj = executor.Figures{
5959
Image: r.Message.Message,
6060
}
6161
} else {

pkg/database/testing.go

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ package database
1616
import (
1717
"fmt"
1818
"log"
19-
"os"
2019
"sync"
2120

2221
"sqlflow.org/gomaxcompute"
2322
"sqlflow.org/sqlflow/pkg/sql/testdata"
23+
"sqlflow.org/sqlflow/pkg/test"
2424

2525
"github.com/go-sql-driver/mysql"
2626
"sqlflow.org/sqlflow/pkg/proto"
@@ -52,7 +52,7 @@ func GetTestingDBSingleton() *DB {
5252
// order to do it, users might want to define TestMain and call
5353
// createTestingDB and defer db.Close in it.
5454
func createTestingDB() *DB {
55-
switch dbms := getEnv("SQLFLOW_TEST_DB", "mysql"); dbms {
55+
switch dbms := test.GetEnv("SQLFLOW_TEST_DB", "mysql"); dbms {
5656
case "mysql":
5757
return createTestingMySQLDB()
5858
case "hive":
@@ -65,20 +65,13 @@ func createTestingDB() *DB {
6565
return nil
6666
}
6767

68-
func getEnv(env, value string) string {
69-
if env := os.Getenv(env); len(env) != 0 {
70-
return env
71-
}
72-
return value
73-
}
74-
7568
// GetTestingMySQLConfig construct a MySQL config
7669
func GetTestingMySQLConfig() *mysql.Config {
7770
return &mysql.Config{
78-
User: getEnv("SQLFLOW_TEST_DB_MYSQL_USER", "root"),
79-
Passwd: getEnv("SQLFLOW_TEST_DB_MYSQL_PASSWD", "root"),
80-
Net: getEnv("SQLFLOW_TEST_DB_MYSQL_NET", "tcp"),
81-
Addr: getEnv("SQLFLOW_TEST_DB_MYSQL_ADDR", "127.0.0.1:3306"),
71+
User: test.GetEnv("SQLFLOW_TEST_DB_MYSQL_USER", "root"),
72+
Passwd: test.GetEnv("SQLFLOW_TEST_DB_MYSQL_PASSWD", "root"),
73+
Net: test.GetEnv("SQLFLOW_TEST_DB_MYSQL_NET", "tcp"),
74+
Addr: test.GetEnv("SQLFLOW_TEST_DB_MYSQL_ADDR", "127.0.0.1:3306"),
8275
AllowNativePasswords: true,
8376
}
8477
}
@@ -119,10 +112,10 @@ func createTestingHiveDB() *DB {
119112

120113
func testingMaxComputeConfig() *gomaxcompute.Config {
121114
return &gomaxcompute.Config{
122-
AccessID: getEnv("SQLFLOW_TEST_DB_MAXCOMPUTE_AK", "test"),
123-
AccessKey: getEnv("SQLFLOW_TEST_DB_MAXCOMPUTE_SK", "test"),
124-
Project: getEnv("SQLFLOW_TEST_DB_MAXCOMPUTE_PROJECT", "test"),
125-
Endpoint: getEnv("SQLFLOW_TEST_DB_MAXCOMPUTE_ENDPOINT", "http://service-maxcompute.com/api"),
115+
AccessID: test.GetEnv("SQLFLOW_TEST_DB_MAXCOMPUTE_AK", "test"),
116+
AccessKey: test.GetEnv("SQLFLOW_TEST_DB_MAXCOMPUTE_SK", "test"),
117+
Project: test.GetEnv("SQLFLOW_TEST_DB_MAXCOMPUTE_PROJECT", "test"),
118+
Endpoint: test.GetEnv("SQLFLOW_TEST_DB_MAXCOMPUTE_ENDPOINT", "http://service-maxcompute.com/api"),
126119
}
127120
}
128121

pkg/database/testing_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ import (
1717
"testing"
1818

1919
"github.com/stretchr/testify/assert"
20+
"sqlflow.org/sqlflow/pkg/test"
2021
)
2122

2223
func TestDatabaseGetTestingDBSingleton(t *testing.T) {
2324
db := GetTestingDBSingleton()
2425
a := assert.New(t)
2526

26-
switch dbms := getEnv("SQLFLOW_TEST_DB", "mysql"); dbms {
27+
switch dbms := test.GetEnv("SQLFLOW_TEST_DB", "mysql"); dbms {
2728
case "mysql":
2829
a.Equal(GetTestingMySQLURL(), db.URL())
2930
case "hive":
Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
// See the License for the specific language governing permissions and
1212
// limitations under the License.
1313

14-
package sql
14+
package executor
1515

1616
import (
1717
"bufio"
@@ -36,11 +36,11 @@ var resourceName = "job.tar.gz"
3636
var entryFile = "entry.py"
3737
var reOSS = regexp.MustCompile(`oss://([^/]+).*host=([^&]+)`)
3838

39-
type alisaSubmitter struct {
40-
*defaultSubmitter
39+
type alisaExecutor struct {
40+
*pythonExecutor
4141
}
4242

43-
func (s *alisaSubmitter) submitAlisaTask(submitCode, codeResourceURL, paramsResourceURL string) error {
43+
func (s *alisaExecutor) submitAlisaTask(submitCode, codeResourceURL, paramsResourceURL string) error {
4444
_, dsName, err := database.ParseURL(s.Session.DbConnStr)
4545
if err != nil {
4646
return err
@@ -63,9 +63,8 @@ func (s *alisaSubmitter) submitAlisaTask(submitCode, codeResourceURL, paramsReso
6363

6464
}
6565

66-
func (s *alisaSubmitter) ExecuteTrain(ts *ir.TrainStmt) (e error) {
67-
ts.TmpTrainTable, ts.TmpValidateTable, e = createTempTrainAndValTable(ts.Select, ts.ValidationSelect, s.Session.DbConnStr)
68-
if e != nil {
66+
func (s *alisaExecutor) ExecuteTrain(ts *ir.TrainStmt) (e error) {
67+
if e = preExecuteTrainOnpPA(ts, s.Session); e != nil {
6968
return e
7069
}
7170
defer dropTmpTables([]string{ts.TmpTrainTable, ts.TmpValidateTable}, s.Session.DbConnStr)
@@ -105,15 +104,15 @@ func (s *alisaSubmitter) ExecuteTrain(ts *ir.TrainStmt) (e error) {
105104
return s.uploadResourceAndSubmitAlisaTask(code, requirements, paiCmd, ts.Estimator)
106105
}
107106

108-
func (s *alisaSubmitter) ExecutePredict(ps *ir.PredictStmt) error {
107+
func (s *alisaExecutor) ExecutePredict(ps *ir.PredictStmt) error {
109108
dbName, tableName, err := createTmpTableFromSelect(ps.Select, s.Session.DbConnStr)
110109
if err != nil {
111110
return err
112111
}
113112
ps.TmpPredictTable = strings.Join([]string{dbName, tableName}, ".")
114113
defer dropTmpTables([]string{ps.TmpPredictTable}, s.Session.DbConnStr)
115114

116-
if e := createPredictionTableFromIR(ps, s.Db, s.Session); e != nil {
115+
if e := createPredictionResultTable(ps, s.Db, s.Session); e != nil {
117116
return e
118117
}
119118

@@ -143,7 +142,7 @@ func (s *alisaSubmitter) ExecutePredict(ps *ir.PredictStmt) error {
143142
return s.uploadResourceAndSubmitAlisaTask(code, requirements, paiCmd, estimator)
144143
}
145144

146-
func (s *alisaSubmitter) uploadResourceAndSubmitAlisaTask(entryCode, requirements, alisaExecCode, estimator string) error {
145+
func (s *alisaExecutor) uploadResourceAndSubmitAlisaTask(entryCode, requirements, alisaExecCode, estimator string) error {
147146
// upload generated program to OSS and submit an Alisa task.
148147
ossCodeObjectName := randStringRunes(16)
149148
alisaBucket, e := getAlisaBucket()
@@ -163,7 +162,7 @@ func (s *alisaSubmitter) uploadResourceAndSubmitAlisaTask(entryCode, requirement
163162
return s.submitAlisaTask(alisaExecCode, codeResourceURL, paramResourceURL)
164163
}
165164

166-
func (s *alisaSubmitter) ExecuteExplain(cl *ir.ExplainStmt) error {
165+
func (s *alisaExecutor) ExecuteExplain(cl *ir.ExplainStmt) error {
167166
dbName, tableName, err := createTmpTableFromSelect(cl.Select, s.Session.DbConnStr)
168167
if err != nil {
169168
return err
@@ -207,7 +206,7 @@ func (s *alisaSubmitter) ExecuteExplain(cl *ir.ExplainStmt) error {
207206
return e
208207
}
209208

210-
func (s *alisaSubmitter) ExecuteEvaluate(es *ir.EvaluateStmt) error {
209+
func (s *alisaExecutor) ExecuteEvaluate(es *ir.EvaluateStmt) error {
211210
dbName, tableName, e := createTmpTableFromSelect(es.Select, s.Session.DbConnStr)
212211
if e != nil {
213212
return e
@@ -252,11 +251,11 @@ func (s *alisaSubmitter) ExecuteEvaluate(es *ir.EvaluateStmt) error {
252251
return s.uploadResourceAndSubmitAlisaTask(code, requirements, paiCmd, estimator)
253252
}
254253

255-
func (s *alisaSubmitter) ExecuteOptimize(es *ir.OptimizeStmt) error {
254+
func (s *alisaExecutor) ExecuteOptimize(es *ir.OptimizeStmt) error {
256255
return fmt.Errorf("ExecuteOptimize is not implemented in alisa submitter")
257256
}
258257

259-
func (s *alisaSubmitter) GetTrainStmtFromModel() bool { return false }
258+
func (s *alisaExecutor) GetTrainStmtFromModel() bool { return false }
260259

261260
func findPyModulePath(pyModuleName string) (string, error) {
262261
var b bytes.Buffer
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
// See the License for the specific language governing permissions and
1212
// limitations under the License.
1313

14-
package sql
14+
package executor
1515

1616
import (
1717
"testing"
@@ -21,7 +21,7 @@ import (
2121

2222
func TestAlisaSubmitter(t *testing.T) {
2323
a := assert.New(t)
24-
_, ok := GetSubmitter("alisa").(*alisaSubmitter)
24+
_, ok := New("alisa").(*alisaExecutor)
2525
a.True(ok)
2626
}
2727

pkg/sql/cmd.go renamed to pkg/executor/cmd.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
// See the License for the specific language governing permissions and
1212
// limitations under the License.
1313

14-
package sql
14+
package executor
1515

1616
import (
1717
"fmt"
Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
// See the License for the specific language governing permissions and
1212
// limitations under the License.
1313

14-
package sql
14+
package executor
1515

1616
import (
1717
"bufio"
@@ -40,35 +40,28 @@ import (
4040

4141
var rePyDiagnosis = regexp.MustCompile("sqlflow_submitter.tensorflow.diag.SQLFlowDiagnostic: (.*)")
4242

43-
// GetSubmitter returns a proper Submitter from configurations in environment variables.
44-
func GetSubmitter(submitter string) Submitter {
45-
if submitter == "" {
46-
submitter = os.Getenv("SQLFLOW_submitter")
47-
}
48-
switch submitter {
49-
case "default":
50-
return &defaultSubmitter{}
51-
case "pai":
52-
return &paiSubmitter{&defaultSubmitter{}}
53-
case "alisa":
54-
return &alisaSubmitter{&defaultSubmitter{}}
55-
// TODO(typhoonzero): add submitters like alps, elasticdl
56-
default:
57-
return &defaultSubmitter{}
58-
}
59-
}
60-
6143
// Figures contains analyzed figures as strings
6244
type Figures struct {
6345
Image string
6446
Text string
6547
}
6648

67-
// Submitter extends ir.Executor
68-
type Submitter interface {
69-
ir.Executor
70-
Setup(*pipe.Writer, *database.DB, string, string, *pb.Session)
71-
GetTrainStmtFromModel() bool
49+
// New returns a proper Submitter from configurations in environment variables.
50+
func New(executor string) ir.Executor {
51+
if executor == "" {
52+
executor = os.Getenv("SQLFLOW_submitter")
53+
}
54+
switch executor {
55+
case "default":
56+
return &pythonExecutor{}
57+
case "pai":
58+
return &paiExecutor{&pythonExecutor{}}
59+
case "alisa":
60+
return &alisaExecutor{&pythonExecutor{}}
61+
// TODO(typhoonzero): add executor like alps, elasticdl
62+
default:
63+
return &pythonExecutor{}
64+
}
7265
}
7366

7467
type logChanWriter struct {
@@ -110,20 +103,20 @@ func (cw *logChanWriter) Close() {
110103
}
111104
}
112105

113-
type defaultSubmitter struct {
106+
type pythonExecutor struct {
114107
Writer *pipe.Writer
115108
Db *database.DB
116109
ModelDir string
117110
Cwd string
118111
Session *pb.Session
119112
}
120113

121-
func (s *defaultSubmitter) Setup(w *pipe.Writer, db *database.DB, modelDir string, cwd string, session *pb.Session) {
114+
func (s *pythonExecutor) Setup(w *pipe.Writer, db *database.DB, modelDir string, cwd string, session *pb.Session) {
122115
// cwd is used to store train scripts and save output models.
123116
s.Writer, s.Db, s.ModelDir, s.Cwd, s.Session = w, db, modelDir, cwd, session
124117
}
125118

126-
func (s *defaultSubmitter) SaveModel(cl *ir.TrainStmt) error {
119+
func (s *pythonExecutor) SaveModel(cl *ir.TrainStmt) error {
127120
m := model.New(s.Cwd, cl.OriginalSQL)
128121
modelURI := cl.Into
129122
if s.ModelDir != "" {
@@ -132,7 +125,7 @@ func (s *defaultSubmitter) SaveModel(cl *ir.TrainStmt) error {
132125
return m.Save(modelURI, cl, s.Session)
133126
}
134127

135-
func (s *defaultSubmitter) runCommand(program string, logStderr bool) error {
128+
func (s *pythonExecutor) runCommand(program string, logStderr bool) error {
136129
cw := &logChanWriter{wr: s.Writer}
137130
defer cw.Close()
138131
cmd := sqlflowCmd(s.Cwd, s.Db.DriverName)
@@ -159,11 +152,11 @@ func (s *defaultSubmitter) runCommand(program string, logStderr bool) error {
159152
return nil
160153
}
161154

162-
func (s *defaultSubmitter) ExecuteQuery(sql *ir.NormalStmt) error {
163-
return runNormalStmt(s.Writer, string(*sql), s.Db)
155+
func (s *pythonExecutor) ExecuteQuery(stmt *ir.NormalStmt) error {
156+
return runNormalStmt(s.Writer, string(*stmt), s.Db)
164157
}
165158

166-
func (s *defaultSubmitter) ExecuteTrain(cl *ir.TrainStmt) (e error) {
159+
func (s *pythonExecutor) ExecuteTrain(cl *ir.TrainStmt) (e error) {
167160
var code string
168161
if isXGBoostModel(cl.Estimator) {
169162
if code, e = xgboost.Train(cl, s.Session); e != nil {
@@ -180,9 +173,9 @@ func (s *defaultSubmitter) ExecuteTrain(cl *ir.TrainStmt) (e error) {
180173
return s.SaveModel(cl)
181174
}
182175

183-
func (s *defaultSubmitter) ExecutePredict(cl *ir.PredictStmt) (e error) {
176+
func (s *pythonExecutor) ExecutePredict(cl *ir.PredictStmt) (e error) {
184177
// NOTE(typhoonzero): model is already loaded under s.Cwd
185-
if e = createPredictionTableFromIR(cl, s.Db, s.Session); e != nil {
178+
if e = createPredictionResultTable(cl, s.Db, s.Session); e != nil {
186179
return e
187180
}
188181

@@ -199,7 +192,7 @@ func (s *defaultSubmitter) ExecutePredict(cl *ir.PredictStmt) (e error) {
199192
return s.runCommand(code, false)
200193
}
201194

202-
func (s *defaultSubmitter) ExecuteExplain(cl *ir.ExplainStmt) error {
195+
func (s *pythonExecutor) ExecuteExplain(cl *ir.ExplainStmt) error {
203196
// NOTE(typhoonzero): model is already loaded under s.Cwd
204197
var code string
205198
var err error
@@ -239,7 +232,7 @@ func (s *defaultSubmitter) ExecuteExplain(cl *ir.ExplainStmt) error {
239232
return nil
240233
}
241234

242-
func (s *defaultSubmitter) ExecuteEvaluate(cl *ir.EvaluateStmt) error {
235+
func (s *pythonExecutor) ExecuteEvaluate(cl *ir.EvaluateStmt) error {
243236
// NOTE(typhoonzero): model is already loaded under s.Cwd
244237
var code string
245238
var err error
@@ -280,7 +273,7 @@ func (s *defaultSubmitter) ExecuteEvaluate(cl *ir.EvaluateStmt) error {
280273
return nil
281274
}
282275

283-
func generateOptFlowOptimizeCodeAndExecute(cl *ir.OptimizeStmt, submitter *defaultSubmitter, session *pb.Session, cwd string, dbName string, tableName string, isPai bool) error {
276+
func generateOptFlowOptimizeCodeAndExecute(cl *ir.OptimizeStmt, submitter *pythonExecutor, session *pb.Session, cwd string, dbName string, tableName string, isPai bool) error {
284277
// Generate optimization code
285278
runnerFileName := "custom_optimize_runner"
286279
runnerCode, submitCode, err := optimize.GenerateOptFlowOptimizeCode(cl, session, dbName, tableName,
@@ -311,7 +304,7 @@ func generateOptFlowOptimizeCodeAndExecute(cl *ir.OptimizeStmt, submitter *defau
311304
return nil
312305
}
313306

314-
func (s *defaultSubmitter) ExecuteOptimize(cl *ir.OptimizeStmt) error {
307+
func (s *pythonExecutor) ExecuteOptimize(cl *ir.OptimizeStmt) error {
315308
// TODO(sneaxiy): to be implemented
316309
return fmt.Errorf("ExecuteOptimize is not supported in default submitter")
317310
}
@@ -359,9 +352,9 @@ func readExplainResult(target string) (string, error) {
359352
return fmt.Sprintf("<div align='center'><img src='data:image/png;base64,%s' /></div>", img), nil
360353
}
361354

362-
func (s *defaultSubmitter) GetTrainStmtFromModel() bool { return true }
355+
func (s *pythonExecutor) GetTrainStmtFromModel() bool { return true }
363356

364-
func (s *defaultSubmitter) ExecuteShowTrain(showTrain *ir.ShowTrainStmt) error {
357+
func (s *pythonExecutor) ExecuteShowTrain(showTrain *ir.ShowTrainStmt) error {
365358
model, err := model.Load(showTrain.ModelName, "", s.Db)
366359
if err != nil {
367360
s.Writer.Write("Load model meta " + showTrain.ModelName + " failed.")

0 commit comments

Comments
 (0)