Skip to content

Commit c89339d

Browse files
committed
update jsrt2
1 parent 77559e0 commit c89339d

3 files changed

Lines changed: 15 additions & 8 deletions

File tree

segmentation/JSRT2/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
In this example, we show how to use a customized CNN and a customized loss function to segment the heart from X-Ray images. The configurations are the same as those in the `JSRT` example except the network structure and loss function.
77

8-
The customized CNN is detailed in `my_net2d.py`, which is a modification of the 2D UNet. In this new network, we use a residual connection in each block. The customized loss is detailed in `my_loss.py`, where we define a focal dice loss.
8+
The customized CNN is detailed in `my_net2d.py`, which is a modification of the 2D UNet. In this new network, we use a residual connection in each block. The customized loss is detailed in `my_loss.py`, where we define a focal dice loss named as MyFocalDiceLoss. We use it `MyFocalDiceLoss + CrossEntropyLoss` to train the custermized network.
99

10-
We also write a customized main function in `jsrt_net_run.py` so that we can combine NetRunAgent from PyMIC with our customized CNN and loss function.
10+
We also write a customized main function in `jsrt_net_run.py` so that we can combine SegmentationAgent from PyMIC with our customized CNN and loss function.
1111

1212
## Data and preprocessing
1313
1. Data preprocessing is the same as that in the the `JSRT` example. Please follow that example for details.
@@ -37,4 +37,4 @@ python net_run_jsrt.py test config/train_test.cfg
3737
pymic_evaluate_seg config/evaluation.cfg
3838
```
3939

40-
The obtained dice score by default setting should be close to 93.90%.
40+
The obtained dice score by default setting should be close to 94.35%.

segmentation/JSRT2/config/train_test.cfg

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,16 @@ bilinear = True
5656
# list of gpus
5757
gpus = [0]
5858

59-
loss_type = MyFocalDiceLoss
59+
loss_type = [MyFocalDiceLoss, CrossEntropyLoss]
60+
loss_weight = [1.0, 1.0]
6061
MyFocalDiceLoss_Enable_Pixel_Weight = False
6162
MyFocalDiceLoss_Enable_Class_Weight = True
6263
MyFocalDiceLoss_beta = 1.5
6364
class_weight = [0.2, 1.0]
6465

66+
CrossEntropyLoss_Enable_Pixel_Weight = False
67+
CrossEntropyLoss_Enable_Class_Weight = False
68+
6569
# for optimizers
6670
optimizer = Adam
6771
learning_rate = 1e-3

segmentation/JSRT2/net_run_jsrt.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,27 @@
44
import sys
55
from pymic.util.parse_config import parse_config
66
from pymic.net_run.agent_seg import SegmentationAgent
7+
from pymic.loss.loss_dict_seg import SegLossDict
78
from my_net2d import MyUNet2D
89
from my_loss import MyFocalDiceLoss
910

11+
loss_dict = {'MyFocalDiceLoss':MyFocalDiceLoss}
12+
loss_dict.update(SegLossDict)
13+
1014
def main():
1115
if(len(sys.argv) < 3):
1216
print('Number of arguments should be 3. e.g.')
13-
print(' python train_infer.py train config.cfg')
17+
print(' python net_run_jsrt.py train config.cfg')
1418
exit()
1519
stage = str(sys.argv[1])
1620
cfg_file = str(sys.argv[2])
1721
config = parse_config(cfg_file)
1822

23+
agent = SegmentationAgent(config, stage)
1924
# use custormized CNN and loss function
2025
mynet = MyUNet2D(config['network'])
21-
myloss = MyFocalDiceLoss(config['training'])
22-
agent = SegmentationAgent(config, stage)
2326
agent.set_network(mynet)
24-
agent.set_loss(myloss)
27+
agent.set_loss_dict(loss_dict)
2528
agent.run()
2629

2730
if __name__ == "__main__":

0 commit comments

Comments
 (0)