# set train dataloader config
# The nested "dataset" entry describes how the MoFlow dataset is built; the
# "sampler", "batch_size" and "num_workers" entries configure batching.
train_dataloader_cfg = {
    "dataset": {
        "name": "MOlFLOWDataset",
        "file_path": cfg.FILE_PATH,
        "data_name": cfg.data_name,
        "mode": cfg.mode,
        "valid_idx": valid_idx,
        "input_keys": cfg.MODEL.input_keys,
        # label and SMILES column names are stored per dataset in the config
        "label_keys": cfg.get(cfg.data_name).label_keys,
        "smiles_col": cfg.get(cfg.data_name).smiles_col,
        "transform_fn": transform_fn,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": True,
    },
    "batch_size": cfg.TRAIN.batch_size,
    "num_workers": cfg.TRAIN.num_workers,
}
# set constraint
output_keys = cfg.MODEL.output_keys
sup_constraint = ppsci.constraint.SupervisedConstraint(
    train_dataloader_cfg,
    ppsci.loss.FunctionalLoss(model.log_prob_loss),
    # one selector per output key; `k=key` binds the key at definition time so
    # every lambda reads its own entry instead of the last loop value
    {key: (lambda out, k=key: out[k]) for key in output_keys},
    name="Sup_constraint",
)
# constraints are registered by name
constraint = {sup_constraint.name: sup_constraint}
# set training hyper-parameters
# Every dataset-specific value is read from the config node named by
# cfg.data_name.
b_hidden_ch = cfg.get(cfg.data_name).b_hidden_ch
a_hidden_gnn = cfg.get(cfg.data_name).a_hidden_gnn
a_hidden_lin = cfg.get(cfg.data_name).a_hidden_lin
mask_row_size_list = list(cfg.get(cfg.data_name).mask_row_size_list)
mask_row_stride_list = list(cfg.get(cfg.data_name).mask_row_stride_list)
# the number of atom types is the length of the atomic number list
a_n_type = len(cfg.get(cfg.data_name).atomic_num_list)
atomic_num_list = list(cfg.get(cfg.data_name).atomic_num_list)

model_params = Hyperparameters(
    b_n_type=cfg.get(cfg.data_name).b_n_type,
    b_n_flow=cfg.get(cfg.data_name).b_n_flow,
    b_n_block=cfg.get(cfg.data_name).b_n_block,
    b_n_squeeze=cfg.get(cfg.data_name).b_n_squeeze,
    b_hidden_ch=b_hidden_ch,
    b_affine=True,
    b_conv_lu=cfg.get(cfg.data_name).b_conv_lu,
    a_n_node=cfg.get(cfg.data_name).a_n_node,
    a_n_type=a_n_type,
    a_hidden_gnn=a_hidden_gnn,
    a_hidden_lin=a_hidden_lin,
    a_n_flow=cfg.get(cfg.data_name).a_n_flow,
    a_n_block=cfg.get(cfg.data_name).a_n_block,
    mask_row_size_list=mask_row_size_list,
    mask_row_stride_list=mask_row_stride_list,
    a_affine=True,
    learn_dist=cfg.get(cfg.data_name).learn_dist,
    seed=cfg.seed,
    noise_scale=cfg.get(cfg.data_name).noise_scale,
)
logger.info("Model params:\n" + tabulate(model_params.print()))
# general settings
mode: train # running mode: train/eval
data_name: qm9 # data select:qm9/zinc250k
seed: 1
output_dir: ${hydra:run.dir}
log_freq: 20

# set training hyper-parameters
qm9:
  b_n_flow: 10
  b_n_block: 1
  b_hidden_ch: [128, 128]
  a_n_flow: 27
  a_n_block: 1
  a_hidden_gnn: [64]
  a_hidden_lin: [128, 64]
  mask_row_size_list: [1]
  mask_row_stride_list: [1]
  learn_dist: True
  noise_scale: 0.6
  b_conv_lu: 1
  atomic_num_list: [6, 7, 8, 9, 0]
  b_n_type: 4
  b_n_squeeze: 3
  a_n_node: 9
  valid_idx: valid_idx_qm9.json
  label_keys: ['A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
  smiles_col: SMILES1

zinc250k:
  b_n_flow: 10
  b_n_block: 1
  b_hidden_ch: [512, 512]
  a_n_flow: 38
  a_n_block: 1
  a_hidden_gnn: [256]
  a_hidden_lin: [512, 64]
  mask_row_size_list: [1]
  mask_row_stride_list: [1]
  learn_dist: True
  noise_scale: 0.6
  b_conv_lu: 2
  atomic_num_list: [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]
  b_n_type: 4
  b_n_squeeze: 19
  a_n_node: 38
  valid_idx: valid_idx_zinc.json
  label_keys: ['logP', 'qed', 'SAS']
  smiles_col: smiles

# set data path
FILE_PATH: ./datasets/moflow

# model settings
MODEL:
  input_keys: ["nodes", "edges"]
  output_keys: ["output", "sum_log_det"]
# set eval dataloader config
# Same dataset description as for training, but the mode is fixed to "eval"
# and no sampler / worker settings are given.
eval_dataloader_cfg = {
    "dataset": {
        "name": "MOlFLOWDataset",
        "file_path": cfg.FILE_PATH,
        "data_name": cfg.data_name,
        "mode": "eval",
        "valid_idx": valid_idx,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.get(cfg.data_name).label_keys,
        "smiles_col": cfg.get(cfg.data_name).smiles_col,
        "transform_fn": transform_fn,
    },
    "batch_size": cfg.EVAL.batch_size,
}

# set validator
sup_validator = ppsci.validate.SupervisedValidator(
    eval_dataloader_cfg,
    ppsci.loss.FunctionalLoss(model.log_prob_loss),
    # `k=key` binds each key at definition time (avoids late-binding closures)
    {key: (lambda out, k=key: out[k]) for key in output_keys},
    metric={
        "Valid": ppsci.metric.FunctionalMetric(
            eval_func(model, cfg.EVAL.batch_size, atomic_num_list)
        )
    },
    name="Sup_Validator",
)
# validators are registered by name
validator = {sup_validator.name: sup_validator}
dataloader_cfg = {
    "dataset": {
        "name": "MOlFLOWDataset",
        "file_path": cfg.FILE_PATH,
        "data_name": cfg.data_name,
        "mode": cfg.mode,
        "valid_idx": valid_idx,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.get(cfg.data_name).label_keys,
        "smiles_col": cfg.get(cfg.data_name).smiles_col,
        "transform_fn": transform_fn,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": True,
    },
    "batch_size": cfg.EVAL.batch_size,
    "num_workers": cfg.EVAL.num_workers,
}
# Build the dataset twice: first with the configured mode, then again after
# switching the dataset mode to "train".
test = ppsci.data.dataset.build_dataset(dataloader_cfg["dataset"])
dataloader_cfg["dataset"].update({"mode": "train"})
train = ppsci.data.dataset.build_dataset(dataloader_cfg["dataset"])
logger.info(
    "{} in total, {} training data, {} testing data, {} batchsize, train/batchsize {}".format(
        len(train) + len(test),
        len(train),
        len(test),
        batch_size,
        len(train) / batch_size,
    )
)

# --- reconstruction check: encode each training batch, decode it again and
# --- compare the SMILES strings of the original and the round-tripped graphs
if cfg.EVAL.reconstruct:
    train_dataloader = ppsci.data.build_dataloader(train, dataloader_cfg)
    reconstruction_rate_list = []
    max_iter = len(train_dataloader)
    input_keys = cfg.MODEL.input_keys
    output_keys = cfg.MODEL.output_keys
    for i, batch in enumerate(train_dataloader):
        output_dict = model(batch[0])
        x = batch[0][input_keys[0]]
        adj = batch[0][input_keys[1]]
        z = output_dict[output_keys[0]]
        # flatten both latent parts to 2-D (keeping their first dimension)
        # so they can be concatenated along axis 1
        z0 = z[0].reshape([tuple(z[0].shape)[0], -1])
        z1 = z[1].reshape([tuple(z[1].shape)[0], -1])
        adj_rev, x_rev = model.reverse(paddle.concat(x=[z0, z1], axis=1))
        reverse_smiles = adj_to_smiles(adj_rev.cpu(), x_rev.cpu(), atomic_num_list)
        train_smiles = adj_to_smiles(adj.cpu(), x.cpu(), atomic_num_list)
        # 1 where the round trip changed the SMILES string, 0 otherwise
        lb = np.array([int(a != b) for a, b in zip(train_smiles, reverse_smiles)])
        idx = np.where(lb)[0]
        # iterating an empty index array is a no-op, so no length guard is needed
        for k in idx:
            logger.info(
                "{}, train: {}, reverse: {}".format(
                    i * batch_size + k, train_smiles[k], reverse_smiles[k]
                )
            )
        reconstruction_rate = 1.0 - lb.mean()
        reconstruction_rate_list.append(reconstruction_rate)
        logger.message(
            "iter/total: {}/{}, reconstruction_rate:{}".format(
                i, max_iter, reconstruction_rate
            )
        )
    reconstruction_rate_total = np.array(reconstruction_rate_list).mean()
    logger.message(
        "reconstruction_rate for all the train data:{} in {}".format(
            reconstruction_rate_total, len(train)
        )
    )
    exit(0)

# --- interpolation between two points
if cfg.EVAL.int2point:
    inputs = train.input
    labels = train.label
    items = []
    for idx in range(len(train)):
        # gather the idx-th sample of every input and label field, then
        # run the combined record through the dataset transform
        input_item = [value[idx] for value in inputs.values()]
        label_item = [value[idx] for value in labels.values()]
        item = input_item + label_item
        item = transform_fn(item)
        items.append(item)
    # transpose so that items[i] holds the i-th field of all samples
    items = np.array(items, dtype=object).T
    inputs = {key: np.stack(items[i], axis=0) for i, key in enumerate(inputs)}

    mol_smiles = None
    gen_dir = osp.join(cfg.output_dir, cfg.EVAL_mode)
    logger.message("Dump figure in {}".format(gen_dir))
    if not osp.exists(gen_dir):
        os.makedirs(gen_dir)
    for seed in range(cfg.EVAL.inter_times):
        filepath = osp.join(
            gen_dir, "2points_interpolation-2point_molecules_seed{}".format(seed)
        )
        visualize_interpolation_between_2_points(
            filepath,
            model,
            mol_smiles=mol_smiles,
            mols_per_row=15,
            n_interpolation=50,
            atomic_num_list=atomic_num_list,
            seed=seed,
            true_data=inputs,
            data_name=cfg.data_name,
        )
    exit(0)

# --- interpolation on a grid
if cfg.EVAL.intgrid:
    inputs = train.input
    labels = train.label
    items = []
    for idx in range(len(train)):
        input_item = [value[idx] for value in inputs.values()]
        label_item = [value[idx] for value in labels.values()]
        item = input_item + label_item
        item = transform_fn(item)
        items.append(item)
    items = np.array(items, dtype=object).T
    inputs = {key: np.stack(items[i], axis=0) for i, key in enumerate(inputs)}

    mol_smiles = None
    gen_dir = os.path.join(cfg.output_dir, cfg.EVAL_mode)
    logger.message("Dump figure in {}".format(gen_dir))
    if not os.path.exists(gen_dir):
        os.makedirs(gen_dir)
    for seed in range(cfg.EVAL.inter_times):
        # first figure keeps duplicate molecules ...
        filepath = os.path.join(
            gen_dir, "generated_interpolation-grid_molecules_seed{}".format(seed)
        )
        visualize_interpolation(
            filepath,
            model,
            mol_smiles=mol_smiles,
            mols_per_row=9,
            delta=cfg.EVAL.delta,
            atomic_num_list=atomic_num_list,
            seed=seed,
            true_data=inputs,
            data_name=cfg.data_name,
            keep_duplicate=True,
        )
        # ... second figure for the same seed drops them
        filepath = os.path.join(
            gen_dir,
            "generated_interpolation-grid_molecules_seed{}_unique".format(seed),
        )
        visualize_interpolation(
            filepath,
            model,
            mol_smiles=mol_smiles,
            mols_per_row=9,
            delta=cfg.EVAL.delta,
            atomic_num_list=atomic_num_list,
            seed=seed,
            true_data=inputs,
            data_name=cfg.data_name,
            keep_duplicate=False,
        )
    exit(0)

inputs = train.input
# set dataloader config
dataloader_cfg = {
    "dataset": {
        "name": "MOlFLOWDataset",
        "file_path": cfg.FILE_PATH,
        "data_name": cfg.data_name,
        "mode": cfg.mode,
        "valid_idx": valid_idx,
        "input_keys": cfg.MODEL.input_keys,
        "label_keys": cfg.get(cfg.data_name).label_keys,
        "smiles_col": cfg.get(cfg.data_name).smiles_col,
        "transform_fn": transform_fn,
    },
    "sampler": {
        "name": "BatchSampler",
        "drop_last": False,
        "shuffle": True,
    },
    "batch_size": cfg.OPTIMIZE.batch_size,
    "num_workers": 0,
}

# set model
model_cfg = dict(cfg.MODEL)
model_cfg.update({"hyper_params": model_params})
model = ppsci.arch.MoFlowNet(**model_cfg)
ppsci.utils.save_load.load_pretrain(model, path=cfg.TRAIN.pretrained_model_path)

# the property regressor wraps the pretrained flow model
model_prop_cfg = dict(cfg.MODEL_Prop)
model_prop_cfg.update(
    {
        "model": model,
        "hidden_size": hidden,
    }
)
property_model = ppsci.arch.MoFlowProp(**model_prop_cfg)

train = ppsci.data.dataset.build_dataset(dataloader_cfg["dataset"])
train_dataloader = ppsci.data.build_dataloader(train, dataloader_cfg)
train_idx = train.train_idx
property_model_path = osp.join(
    cfg.output_dir, "{}_model.pdparams".format(cfg.OPTIMIZE.property_name)
)

if not osp.exists(property_model_path):
    # no saved regressor yet: fit one and store its weights
    logger.message("Training regression model over molecular embedding:")
    property_csv_path = osp.join(
        cfg.FILE_PATH, "{}_property.csv".format(cfg.data_name)
    )
    prop_list = load_property_csv(property_csv_path, normalize=True)
    train_prop = [prop_list[i] for i in train_idx]
    # test_prop = [prop_list[i] for i in valid_idx]
    N = len(train)
    property_model = fit_model(
        property_model,
        train_dataloader,
        train_prop,
        N,
        property_name=cfg.OPTIMIZE.property_name,
        max_epochs=cfg.OPTIMIZE.max_epochs,
        learning_rate=cfg.OPTIMIZE.learning_rate,
        weight_decay=cfg.OPTIMIZE.weight_decay,
    )
    logger.message(
        "saving {} regression model to: {}".format(
            cfg.OPTIMIZE.property_name, property_model_path
        )
    )
    paddle.save(obj=property_model.state_dict(), path=property_model_path)
else:
    # NOTE(review): the collapsed source does not show whether the statements
    # from `property_model.eval()` onward belong to this `else` branch or
    # follow the whole if/else; they are presumably nested here — confirm
    # against the original script.
    logger.message("Loading trained regression model for optimization")
    property_csv_path = osp.join(
        cfg.FILE_PATH, "{}_property.csv".format(cfg.data_name)
    )
    prop_list = load_property_csv(property_csv_path, normalize=True)
    train_prop = [prop_list[i] for i in train_idx]
    # test_prop = [prop_list[i] for i in valid_idx]
    logger.message(
        "loading {} regression model from: {}".format(
            cfg.OPTIMIZE.property_name, property_model_path
        )
    )
    state_dict = paddle.load(path=property_model_path)
    property_model.set_state_dict(state_dict)
    property_model.eval()
    model.eval()

    if cfg.OPTIMIZE.topscore:
        logger.message("Finding top score:")
        find_top_score_smiles(
            model,
            property_model,
            cfg.data_name,
            cfg.OPTIMIZE.property_name,
            train_prop,
            cfg.OPTIMIZE.topk,
            atomic_num_list,
            cfg.OPTIMIZE.debug,
            cfg.output_dir,
        )

    if cfg.OPTIMIZE.consopt:
        logger.message("Constrained optimization:")
        constrain_optimization_smiles(
            model,
            property_model,
            cfg.data_name,
            cfg.OPTIMIZE.property_name,
            train_prop,
            cfg.OPTIMIZE.topk,
            atomic_num_list,
            cfg.OPTIMIZE.debug,
            cfg.output_dir,
            sim_cutoff=cfg.OPTIMIZE.sim_cutoff,
        )
# optimize settings
OPTIMIZE:
  property_name: plogp # qed/plogp
  batch_size: 256
  topk: 800
  debug: false
  topscore: false
  max_epochs: 3
  learning_rate: 0.001
  weight_decay: 1e-2
  hidden: [16] # Hidden dimension list for output regression
  temperature: 1.0
  consopt: true
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from os import path as osp

import hydra
import moflow_transform
import numpy as np
import paddle
from moflow_utils import Hyperparameters
from moflow_utils import check_validity
from omegaconf import DictConfig
from tabulate import tabulate

import ppsci
from ppsci.utils import logger


def infer(model, batch_size=20, temp=0.7, z_mu=None, true_adj=None):
    """generate mols

    Draws ``batch_size`` latent vectors from a normal distribution and decodes
    them with ``model.reverse``.

    Args:
        model (object): Generated eval Moflownet model
        batch_size (int, optional): Batch size during evaling per GPU. Defaults to 20.
        temp (float, optional): temperature of the gaussian distribution. Defaults to 0.7.
        z_mu (optional): latent vector of a molecule, used as the mean of the
            sampling distribution when given. Defaults to None.
        true_adj (paddle.Tensor, optional): True Adjacency. Defaults to None.

    Returns:
        Tuple(paddle.Tensor, paddle.Tensor): Adjacency and nodes
    """
    # the latent vector is laid out as [first b_size entries | a_size entries]
    z_dim = model.b_size + model.a_size
    mu = np.zeros(z_dim)
    sigma_diag = np.ones(z_dim)

    if model.hyper_params.learn_dist:
        if len(model.ln_var) == 1:
            # a single learned log-variance shared by the whole latent vector
            sigma_diag = np.sqrt(np.exp(model.ln_var.item())) * sigma_diag
        elif len(model.ln_var) == 2:
            # one learned log-variance per latent part
            sigma_diag[: model.b_size] = (
                np.sqrt(np.exp(model.ln_var[0].item())) * sigma_diag[: model.b_size]
            )
            # The second part starts at index b_size. The previous slice
            # `b_size + 1:` skipped that entry and left its sigma at 1.
            sigma_diag[model.b_size :] = (
                np.sqrt(np.exp(model.ln_var[1].item())) * sigma_diag[model.b_size :]
            )

    sigma = temp * sigma_diag

    with paddle.no_grad():
        if z_mu is not None:
            mu = z_mu
            # NOTE(review): this passes a (z_dim, z_dim) matrix as `scale` to
            # np.random.normal, which only broadcasts against
            # (batch_size, z_dim) when batch_size == z_dim (or z_dim == 1) —
            # confirm the intended per-dimension std before relying on this path.
            sigma = 0.01 * np.eye(z_dim)
        z = np.random.normal(mu, sigma, (batch_size, z_dim))
        z = paddle.to_tensor(data=z).astype(paddle.get_default_dtype())
        adj, x = model.reverse(z, true_adj=true_adj)
    return adj, x


class eval_func:
    """Metric callable that samples molecules and reports validity ratios.

    Args:
        metrics_mode (object): model used for sampling; it is switched to eval
            mode for the sampling and back to train mode afterwards.
        batch_size (int): number of molecules sampled per call.
        atomic_num_list (list): atomic numbers forwarded to ``check_validity``.
    """

    def __init__(
        self,
        metrics_mode,
        batch_size,
        atomic_num_list,
        *args,
    ):
        super().__init__()
        self.metrics_mode = metrics_mode
        self.batch_size = batch_size
        self.atomic_num_list = atomic_num_list

    def __call__(
        self,
        output_dict,
        label_dict,
    ):
        """Sample a batch and return its validity statistics.

        ``output_dict`` and ``label_dict`` are accepted to match the metric
        call signature but are not read: the metric is computed on freshly
        sampled molecules.

        Returns:
            dict: keys "valid", "unique" and "abs_unique".
        """
        self.metrics_mode.eval()
        adj, x = infer(self.metrics_mode, self.batch_size)
        validity_info = check_validity(adj, x, self.atomic_num_list)
        # restore training mode for the caller
        self.metrics_mode.train()
        results = dict()
        results["valid"] = validity_info["valid_ratio"]
        results["unique"] = validity_info["unique_ratio"]
        results["abs_unique"] = validity_info["abs_unique_ratio"]
        return results


def train(cfg: DictConfig):
    """Build the model, data pipeline and solver from ``cfg``, then train and evaluate.

    Args:
        cfg (DictConfig): Hydra/OmegaConf configuration. Dataset-specific
            values are read from the node named by ``cfg.data_name``.

    Raises:
        ValueError: if ``cfg.data_name`` is neither "qm9" nor "zinc250k".
    """
    # all dataset-specific settings live under the key named by cfg.data_name
    data_cfg = cfg.get(cfg.data_name)

    # set training hyper-parameters
    b_hidden_ch = data_cfg.b_hidden_ch
    a_hidden_gnn = data_cfg.a_hidden_gnn
    a_hidden_lin = data_cfg.a_hidden_lin
    mask_row_size_list = list(data_cfg.mask_row_size_list)
    mask_row_stride_list = list(data_cfg.mask_row_stride_list)
    a_n_type = len(data_cfg.atomic_num_list)
    atomic_num_list = list(data_cfg.atomic_num_list)

    model_params = Hyperparameters(
        b_n_type=data_cfg.b_n_type,
        b_n_flow=data_cfg.b_n_flow,
        b_n_block=data_cfg.b_n_block,
        b_n_squeeze=data_cfg.b_n_squeeze,
        b_hidden_ch=b_hidden_ch,
        b_affine=True,
        b_conv_lu=data_cfg.b_conv_lu,
        a_n_node=data_cfg.a_n_node,
        a_n_type=a_n_type,
        a_hidden_gnn=a_hidden_gnn,
        a_hidden_lin=a_hidden_lin,
        a_n_flow=data_cfg.a_n_flow,
        a_n_block=data_cfg.a_n_block,
        mask_row_size_list=mask_row_size_list,
        mask_row_stride_list=mask_row_stride_list,
        a_affine=True,
        learn_dist=data_cfg.learn_dist,
        seed=cfg.seed,
        noise_scale=data_cfg.noise_scale,
    )

    logger.info("Model params:\n" + tabulate(model_params.print()))

    # set transforms
    if cfg.data_name == "qm9":
        transform_fn = moflow_transform.transform_fn
    elif cfg.data_name == "zinc250k":
        transform_fn = moflow_transform.transform_fn_zinc250k
    else:
        # fail early with a clear message instead of an unbound-name error
        # further down when the dataloader config is assembled
        raise ValueError(
            "Unsupported data_name {!r}; expected 'qm9' or 'zinc250k'.".format(
                cfg.data_name
            )
        )

    # set select eval data
    valid_idx_path = osp.join(cfg.FILE_PATH, data_cfg.valid_idx)
    valid_idx = moflow_transform.get_val_ids(valid_idx_path, cfg.data_name)

    # set train dataloader config
    train_dataloader_cfg = {
        "dataset": {
            "name": "MOlFLOWDataset",
            "file_path": cfg.FILE_PATH,
            "data_name": cfg.data_name,
            "mode": cfg.mode,
            "valid_idx": valid_idx,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": data_cfg.label_keys,
            "smiles_col": data_cfg.smiles_col,
            "transform_fn": transform_fn,
        },
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "num_workers": cfg.TRAIN.num_workers,
    }

    # set model
    model_cfg = dict(cfg.MODEL)
    model_cfg.update({"hyper_params": model_params})
    model = ppsci.arch.MoFlowNet(**model_cfg)

    # set constraint
    output_keys = cfg.MODEL.output_keys
    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        ppsci.loss.FunctionalLoss(model.log_prob_loss),
        # `k=key` binds each key at definition time (avoids late binding)
        {key: (lambda out, k=key: out[k]) for key in output_keys},
        name="Sup_constraint",
    )

    constraint = {sup_constraint.name: sup_constraint}

    # set iters_per_epoch by dataloader length
    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

    # init optimizer and lr scheduler
    optimizer = ppsci.optimizer.Adam(cfg.TRAIN.learning_rate)(model)

    # set eval dataloader config
    eval_dataloader_cfg = {
        "dataset": {
            "name": "MOlFLOWDataset",
            "file_path": cfg.FILE_PATH,
            "data_name": cfg.data_name,
            "mode": "eval",
            "valid_idx": valid_idx,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": data_cfg.label_keys,
            "smiles_col": data_cfg.smiles_col,
            "transform_fn": transform_fn,
        },
        "batch_size": cfg.EVAL.batch_size,
    }

    # set validator
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        ppsci.loss.FunctionalLoss(model.log_prob_loss),
        {key: (lambda out, k=key: out[k]) for key in output_keys},
        metric={
            "Valid": ppsci.metric.FunctionalMetric(
                eval_func(model, cfg.EVAL.batch_size, atomic_num_list)
            )
        },
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        cfg.output_dir,
        optimizer,
        None,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        seed=cfg.seed,
        validator=validator,
        save_freq=cfg.TRAIN.save_freq,
        eval_during_train=cfg.TRAIN.eval_during_train,
        eval_freq=cfg.TRAIN.eval_freq,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )
    # train model
    solver.train()

    # validation for training
    solver.eval()


@hydra.main(version_base=None, config_path="./conf", config_name="moflow_train.yaml")
def main(cfg: DictConfig):
    """Hydra entry point: run training with the loaded configuration."""
    train(cfg)


if __name__ == "__main__":
    main()