Skip to content

Commit c59d660

Browse files
authored
SQLFlow cli support model/repo list (#2750)
* refine pai code structure * remove unused function * move tensorflow files to a sub-package * add docstring * Add KMeans and RF PAI submitter * use model type defined in runtim.model * fix string compare * Evaluate for PAI * modify code * update cli design to support model zoo list * update doc * Add list model zoo option to cli
1 parent c5300f1 commit c59d660

5 files changed

Lines changed: 138 additions & 1 deletion

File tree

go/cmd/sqlflow/main.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ Usage:
6262
sqlflow [options] release model [--force] [--local] [--desc=<desc>] <model_name> <version>
6363
sqlflow [options] delete repo <repo_name> <version>
6464
sqlflow [options] delete model <model_name> <version>
65+
sqlflow [options] list repo
66+
sqlflow [options] list model
6567
6668
Options:
6769
-v, --version print the version and exit
@@ -70,7 +72,9 @@ Options:
7072
--env-file=<file> config file in KEY=VAL format
7173
-s, --sqlflow-server=<addr> SQLFlow server address and port, e.g localhost:50051
7274
-m, --model-zoo-server=<addr> Model Zoo server address and port
73-
-d, --data-source=<data_source> data source to use when run or release model
75+
-d, --data-source=<data_source> data source to use when run or release model
76+
-u, --user=<user> Model Zoo user account
77+
-p, --password=<password> Model Zoo user password
7478
7579
Run Options:
7680
-e, --execute=<program> execute given program
@@ -98,6 +102,9 @@ type options struct {
98102
ModelName string `docopt:"<model_name>"`
99103
Version string `docopt:"<version>"`
100104
Description string `docopt:"--desc"`
105+
List bool
106+
User string
107+
Password string
101108
}
102109

103110
func isSpace(c byte) bool {
@@ -449,6 +456,10 @@ func processOptions(opts *options) {
449456
err = deleteModel(opts)
450457
case opts.Delete && opts.Repo:
451458
err = deleteRepo(opts)
459+
case opts.List && opts.Model:
460+
err = listModels(opts)
461+
case opts.List && opts.Repo:
462+
err = listRepos(opts)
452463
default:
453464
err = runSQLFlowClient(opts)
454465
}

go/cmd/sqlflow/main_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,16 @@ func TestParseArgument(t *testing.T) {
11151115
a.Equal("my_model", opts.ModelName)
11161116
a.Equal("v1.0", opts.Version)
11171117

1118+
opts, err = getOptions("list model")
1119+
a.NoError(err)
1120+
a.True(opts.List && opts.Model)
1121+
a.False(opts.Run || opts.Release || opts.Delete)
1122+
1123+
opts, err = getOptions("list repo")
1124+
a.NoError(err)
1125+
a.True(opts.List && opts.Repo)
1126+
a.False(opts.Run || opts.Release || opts.Delete)
1127+
11181128
// invalid args
11191129
opts, err = getOptions("kill model my_model")
11201130
a.Error(err)

go/cmd/sqlflow/model_zoo_client.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"sqlflow.org/sqlflow/go/database"
2727
"sqlflow.org/sqlflow/go/model"
2828
pb "sqlflow.org/sqlflow/go/proto"
29+
"sqlflow.org/sqlflow/go/step/tablewriter"
2930
"sqlflow.org/sqlflow/go/tar"
3031
)
3132

@@ -267,3 +268,93 @@ func deleteRepo(opts *options) error {
267268
}
268269
return nil
269270
}
271+
272+
func listModels(opts *options) error {
273+
if err := checkModelZooParam(opts); err != nil {
274+
return err
275+
}
276+
conn, err := getModelZooServerConn(opts)
277+
if err != nil {
278+
return err
279+
}
280+
defer conn.Close()
281+
282+
writer, err := tablewriter.Create("ascii", 1024, os.Stdout)
283+
if err != nil {
284+
return err
285+
}
286+
writer.SetHeader(map[string]interface{}{
287+
"columnNames": []string{
288+
"Name", "Tag", "ModelStoreUrl", "ImageUrl", "Description", "Metric"},
289+
})
290+
291+
start := int64(0)
292+
client := pb.NewModelZooServerClient(conn)
293+
for {
294+
req := &pb.ListModelRequest{
295+
// (TODO: lhw) add authentication information in request
296+
Start: start,
297+
Size: 100,
298+
}
299+
resp, err := client.ListModels(context.Background(), req)
300+
if err != nil {
301+
return err
302+
}
303+
if resp.Size <= 0 {
304+
break
305+
}
306+
for _, m := range resp.ModelList {
307+
writer.AppendRow([]interface{}{
308+
m.Name, m.Tag, m.ModelStoreUrl, m.ImageUrl, m.Description, m.Metric,
309+
})
310+
}
311+
start += resp.Size
312+
}
313+
writer.Flush()
314+
return nil
315+
}
316+
317+
func listRepos(opts *options) error {
318+
if err := checkModelZooParam(opts); err != nil {
319+
return err
320+
}
321+
conn, err := getModelZooServerConn(opts)
322+
if err != nil {
323+
return err
324+
}
325+
defer conn.Close()
326+
327+
writer, err := tablewriter.Create("ascii", 1024, os.Stdout)
328+
if err != nil {
329+
return err
330+
}
331+
writer.SetHeader(map[string]interface{}{
332+
"columnNames": []string{
333+
"ClassName", "ImageUrl", "Tag", "ArgDescs"},
334+
})
335+
336+
start := int64(0)
337+
client := pb.NewModelZooServerClient(conn)
338+
for {
339+
req := &pb.ListModelRequest{
340+
// (TODO: lhw) add authentication information in request
341+
Start: start,
342+
Size: 100,
343+
}
344+
resp, err := client.ListModelRepos(context.Background(), req)
345+
if err != nil {
346+
return err
347+
}
348+
if resp.Size <= 0 {
349+
break
350+
}
351+
for _, r := range resp.ModelDefList {
352+
writer.AppendRow([]interface{}{
353+
r.ClassName, r.ImageUrl, r.Tag, r.ArgDescs,
354+
})
355+
}
356+
start += int64(resp.Size)
357+
}
358+
writer.Flush()
359+
return nil
360+
}

go/cmd/sqlflow/model_zoo_client_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/stretchr/testify/assert"
2424
"sqlflow.org/sqlflow/go/database"
2525
"sqlflow.org/sqlflow/go/modelzooserver"
26+
"sqlflow.org/sqlflow/go/step"
2627
)
2728

2829
const modelZooServerPort = 50055
@@ -130,6 +131,26 @@ func CaseDeleteModel(t *testing.T) {
130131
a.NoError(deleteModel(opts))
131132
}
132133

134+
func caseListModels(t *testing.T) {
135+
a := assert.New(t)
136+
cmd := fmt.Sprintf("--model-zoo-server=localhost:%d list model", modelZooServerPort)
137+
opts, err := getOptions(cmd)
138+
a.NoError(err)
139+
out, err := step.GetStdout(func() error { listModels(opts); return nil })
140+
a.NoError(err)
141+
a.Contains(out, "iris.my_model")
142+
}
143+
144+
func caseListRepos(t *testing.T) {
145+
a := assert.New(t)
146+
cmd := fmt.Sprintf("--model-zoo-server=localhost:%d list repo", modelZooServerPort)
147+
opts, err := getOptions(cmd)
148+
a.NoError(err)
149+
out, err := step.GetStdout(func() error { listRepos(opts); return nil })
150+
a.NoError(err)
151+
a.Contains(out, "DNNClassifier")
152+
}
153+
133154
func TestModelZooOperation(t *testing.T) {
134155
a := assert.New(t)
135156
startTestModelZooServer()
@@ -142,6 +163,8 @@ func TestModelZooOperation(t *testing.T) {
142163
t.Run("caseTrainModel", caseTrainModel)
143164
t.Run("caseReleaseModel", caseReleaseModel)
144165
t.Run("caseReleaseModelLocal", caseReleaseModelLocal)
166+
t.Run("caseListModels", caseListModels)
167+
t.Run("caseListRepos", caseListRepos)
145168
t.Run("caseDeleteModel", CaseDeleteModel)
146169
t.Run("caseDeleteRepo", caseDeleteRepo)
147170
}

go/modelzooserver/modelzooserver.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ func (s *modelZooServer) ListModelRepos(ctx context.Context, req *pb.ListModelRe
140140
responseList.ModelDefList,
141141
perResp,
142142
)
143+
responseList.Size++
143144
}
144145
return responseList, nil
145146
}
@@ -187,6 +188,7 @@ LEFT JOIN %s AS c ON b.model_coll_id=c.id LIMIT %d OFFSET %d;`,
187188
trainedModelList.ModelList,
188189
perResp,
189190
)
191+
trainedModelList.Size++
190192
}
191193

192194
return trainedModelList, nil

0 commit comments

Comments
 (0)