class SoftCrossEntropyLoss(paddle.nn.Layer):
    """Cross-entropy loss with optional label smoothing.

    Meant as a drop-in replacement for ``nn.CrossEntropyLoss``; the one
    addition is label smoothing, controlled by ``smooth_factor``.
    """

    __constants__ = ["reduction", "ignore_index", "smooth_factor"]

    def __init__(
        self,
        reduction: str = "mean",
        smooth_factor: float = 0.0,
        ignore_index: Optional[int] = -100,
        dim=1,
    ):
        """Store the loss configuration.

        :param reduction: Reduction mode forwarded to ``label_smoothed_nll_loss``.
        :param smooth_factor: Smoothing amount, forwarded as ``epsilon``.
        :param ignore_index: Target value forwarded as ``ignore_index``.
        :param dim: Axis along which the log-softmax is taken.
        """
        super().__init__()
        self.dim = dim
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.smooth_factor = smooth_factor

    def forward(self, input: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:
        """Return the label-smoothed negative log-likelihood of ``input``.

        :param input: Raw scores; log-softmax is applied along ``self.dim``.
        :param target: Targets passed unchanged to ``label_smoothed_nll_loss``.
        :return: Whatever ``label_smoothed_nll_loss`` returns for the
            configured reduction.
        """
        log_probabilities = paddle.nn.functional.log_softmax(x=input, axis=self.dim)
        smoothed_loss = label_smoothed_nll_loss(
            log_probabilities,
            target,
            epsilon=self.smooth_factor,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            dim=self.dim,
        )
        return smoothed_loss
class DiceLoss(paddle.nn.Layer):
    """Dice loss for image segmentation.

    Supports the binary, multiclass and multilabel cases.
    """

    def __init__(
        self,
        mode: str = "multiclass",
        classes: Optional[List[int]] = None,
        log_loss=False,
        from_logits=True,
        smooth: float = 0.0,
        ignore_index=None,
        eps=1e-07,
    ):
        """
        :param mode: Loss mode; must be one of BINARY_MODE, MULTICLASS_MODE or
            MULTILABEL_MODE.
        :param classes: Optional list of class indices that contribute to the
            loss; by default all channels are included. Not supported in
            binary mode.
        :param log_loss: If True, the loss is ``-log(dice)``; otherwise
            ``1 - dice``.
        :param from_logits: If True, ``y_pred`` is treated as raw logits and an
            activation is applied in ``forward``.
        :param smooth: Smoothing constant forwarded to ``soft_dice_score``.
        :param ignore_index: Label that marks ignored pixels (they do not
            contribute to the loss).
        :param eps: Small epsilon for numerical stability; forwarded to
            ``soft_dice_score`` and used as the lower clip bound in log mode.
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super(DiceLoss, self).__init__()
        self.mode = mode
        if classes is not None:
            assert (
                mode != BINARY_MODE
            ), "Masking classes is not supported with mode=binary"
            classes = to_tensor(classes, dtype="int64")
        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.eps = eps
        self.ignore_index = ignore_index
        self.log_loss = log_loss

    def _mask_ignored(self, y_pred, y_true):
        """Zero out predictions and targets wherever ``y_true == ignore_index``."""
        if self.ignore_index is None:
            return y_pred, y_true
        keep = paddle.cast(y_true != self.ignore_index, dtype="float32")
        return y_pred * keep, y_true * keep

    def forward(self, y_pred: paddle.Tensor, y_true: paddle.Tensor) -> paddle.Tensor:
        """
        :param y_pred: NxCxHxW
        :param y_true: NxHxW
        :return: scalar
        """
        assert y_true.shape[0] == y_pred.shape[0]

        if self.from_logits:
            # Probabilities are obtained by exponentiating the log-space
            # activation (log_softmax / log_sigmoid), presumably for
            # numerical stability at extreme logits.
            if self.mode == MULTICLASS_MODE:
                y_pred = paddle.nn.functional.log_softmax(y_pred, axis=1).exp()
            else:
                y_pred = paddle.nn.functional.log_sigmoid(x=y_pred).exp()

        bs = y_true.shape[0]
        num_classes = y_pred.shape[1]
        # Reduce over the batch axis and the flattened spatial axis, leaving
        # one score per channel.
        dims = (0, 2)

        # ``reshape`` with an explicit shape list is the documented Paddle
        # form; it yields the same values as the torch-style ``view(*shape)``.
        if self.mode == BINARY_MODE:
            y_true = y_true.reshape([bs, 1, -1])
            y_pred = y_pred.reshape([bs, 1, -1])
            y_pred, y_true = self._mask_ignored(y_pred, y_true)

        if self.mode == MULTICLASS_MODE:
            y_true = y_true.reshape([bs, -1])
            y_pred = y_pred.reshape([bs, num_classes, -1])

            if self.ignore_index is not None:
                keep = y_true != self.ignore_index
                keep_float = paddle.cast(keep, dtype="float32")
                y_pred = paddle.cast(
                    y_pred * keep_float.unsqueeze(axis=1), dtype="float32"
                )
                # Ignored labels may lie outside [0, num_classes); map them
                # to class 0 before one-hot encoding, then zero those
                # positions again with the mask afterwards.
                keep_label = paddle.cast(keep, dtype=y_true.dtype)
                masked_y_true = (y_true * keep_label).astype("int64")
                y_true = paddle.nn.functional.one_hot(
                    num_classes=num_classes, x=masked_y_true
                ).astype("int64")
                keep_int = paddle.cast(keep, dtype="int64")
                # (N, H*W, C) -> (N, C, H*W)
                y_true = y_true.transpose(perm=[0, 2, 1]) * keep_int.unsqueeze(axis=1)
            else:
                y_true = paddle.nn.functional.one_hot(
                    num_classes=num_classes, x=y_true
                ).astype("int64")
                # (N, H*W, C) -> (N, C, H*W)
                y_true = y_true.transpose(perm=[0, 2, 1])

        if self.mode == MULTILABEL_MODE:
            y_true = y_true.reshape([bs, num_classes, -1])
            y_pred = y_pred.reshape([bs, num_classes, -1])
            y_pred, y_true = self._mask_ignored(y_pred, y_true)

        scores = soft_dice_score(
            y_pred,
            y_true.astype(dtype=y_pred.dtype),
            smooth=self.smooth,
            eps=self.eps,
            dims=dims,
        )

        if self.log_loss:
            loss = -paddle.log(x=scores.clip(min=self.eps))
        else:
            loss = 1.0 - scores

        # Channels with no positive ground-truth entries contribute zero loss.
        has_target = y_true.sum(axis=dims) > 0
        loss = loss * has_target.astype(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]

        return loss.mean()
class JointLoss(paddle.nn.Layer):
    """Combine two loss layers into a single weighted sum.

    Each wrapped loss is paired with its weight through ``WeightedLoss``;
    calling the layer evaluates both on the same arguments and adds them.
    """

    def __init__(
        self,
        first: paddle.nn.Layer,
        second: paddle.nn.Layer,
        first_weight=1.0,
        second_weight=1.0,
    ):
        """
        :param first: First loss layer.
        :param second: Second loss layer.
        :param first_weight: Weight handed to ``WeightedLoss`` for ``first``.
        :param second_weight: Weight handed to ``WeightedLoss`` for ``second``.
        """
        super().__init__()
        self.first = WeightedLoss(first, first_weight)
        self.second = WeightedLoss(second, second_weight)

    def forward(self, *input):
        """Evaluate both wrapped losses on ``*input`` and return their sum."""
        first_term = self.first(*input)
        second_term = self.second(*input)
        return first_term + second_term
# NOTE(review): this is the tail of the training entry point. The enclosing
# header and the definition of `config` are not visible in this part of the
# file -- confirm placement/indentation against the full source.
checkpoint_callback = ModelCheckpoint(
    save_top_k=config.save_top_k,
    monitor=config.monitor,
    save_last=config.save_last,
    mode=config.monitor_mode,
    dirpath=config.weights_path,
    filename=config.weights_name,
)
logger = CSVLogger("lightning_logs", name=config.log_name)

model = Supervision_Train(config)
if config.pretrained_ckpt_path:
    state_dict = paddle.load(config.pretrained_ckpt_path)
    model.set_state_dict(state_dict)

paddle.set_device("gpu")
optimizer, lr_scheduler = model.configure_optimizers()
train_loader = model.train_dataloader()
val_loader = model.val_dataloader()

# Restore the resume checkpoint BEFORE the first epoch. Previously this ran
# at the end of epoch 0, so a full epoch was trained (and the scheduler
# stepped) only to be overwritten by the checkpoint state.
if config.resume_ckpt_path:
    state = paddle.load(config.resume_ckpt_path)
    model.set_state_dict(state["model_state_dict"])
    optimizer.set_state_dict(state["optimizer_state_dict"])
    if lr_scheduler and "lr_scheduler_state_dict" in state:
        lr_scheduler.set_state_dict(state["lr_scheduler_state_dict"])
    print(f"Resumed training from checkpoint: {config.resume_ckpt_path}")

for epoch in range(config.max_epoch):
    print(f"Epoch {epoch+1}/{config.max_epoch}")

    # ---- training ----
    model.train()
    train_losses = []
    for batch_idx, batch in enumerate(train_loader):
        output = model.training_step(batch, batch_idx)
        loss = output["loss"]
        # Read the scalar once; it is reused for the running log below.
        loss_value = loss.item()
        train_losses.append(loss_value)

        loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        if batch_idx % 10 == 0:
            print(f" Batch {batch_idx}/{len(train_loader)}, Loss: {loss_value:.4f}")

    train_log = model.on_train_epoch_end()
    train_log["loss"] = np.mean(train_losses)

    # ---- validation (every `check_val_every_n_epoch` epochs) ----
    if (epoch + 1) % config.check_val_every_n_epoch == 0:
        model.eval()
        val_losses = []
        for batch_idx, batch in enumerate(val_loader):
            output = model.validation_step(batch, batch_idx)
            val_losses.append(output["loss_val"].item())

        val_log = model.on_validation_epoch_end()
        val_log["loss_val"] = np.mean(val_losses)

        checkpoint_callback.on_validation_epoch_end(None, model, val_log)
        # Logged here because `log_metrics` needs `val_log`, which only
        # exists on validation epochs.
        logger.log_metrics(epoch, train_log, val_log)

    if lr_scheduler:
        lr_scheduler.step()

if __name__ == "__main__":
    main()