Skip to content

Commit

Permalink
[HWORKS-936] Explicit model provenance backend (#1513)
Browse files Browse the repository at this point in the history
* [HWORKS-938] explicit provenance - model (#1725)

* [HWORKS-938] explicit provenance - model

* fix collector key bug

* fix

* cleanup

* fixes

* fixes

* fix build

* remove unnecessary imports from tests

---------

Co-authored-by: Alexandru Ormenisan <[email protected]>
Co-authored-by: Alexandru Ormenisan <[email protected]>
Co-authored-by: bubriks <[email protected]>

* fix comunity

* [HWORKS-936][APPEND] explicit provenance - model (#1743)

* init

* add expand users to models

* mini fix

---------

Co-authored-by: Alex Ormenisan <[email protected]>
Co-authored-by: Alexandru Ormenisan <[email protected]>
Co-authored-by: Alexandru Ormenisan <[email protected]>
  • Loading branch information
4 people authored Mar 25, 2024
1 parent 8867f72 commit 70e2683
Show file tree
Hide file tree
Showing 19 changed files with 684 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,44 @@
import io.hops.hopsworks.api.modelregistry.models.dto.ModelDTO;
import io.hops.hopsworks.common.dao.project.ProjectFacade;
import io.hops.hopsworks.common.dataset.DatasetController;
import io.hops.hopsworks.common.featurestore.FeaturestoreFacade;
import io.hops.hopsworks.common.featurestore.featureview.FeatureViewController;
import io.hops.hopsworks.common.featurestore.featureview.FeatureViewFacade;
import io.hops.hopsworks.common.featurestore.trainingdatasets.TrainingDatasetFacade;
import io.hops.hopsworks.common.hdfs.DistributedFileSystemOps;
import io.hops.hopsworks.common.hdfs.DistributedFsService;
import io.hops.hopsworks.common.hdfs.HdfsUsersController;
import io.hops.hopsworks.common.hdfs.Utils;
import io.hops.hopsworks.common.provenance.explicit.ModelLinkController;
import io.hops.hopsworks.common.util.AccessController;
import io.hops.hopsworks.common.util.Settings;
import io.hops.hopsworks.exceptions.DatasetException;
import io.hops.hopsworks.exceptions.FeaturestoreException;
import io.hops.hopsworks.exceptions.GenericException;

import io.hops.hopsworks.exceptions.ProjectException;
import io.hops.hopsworks.persistence.entity.dataset.Dataset;
import io.hops.hopsworks.persistence.entity.featurestore.Featurestore;
import io.hops.hopsworks.persistence.entity.featurestore.featureview.FeatureView;
import io.hops.hopsworks.persistence.entity.featurestore.trainingdataset.TrainingDataset;
import io.hops.hopsworks.persistence.entity.models.version.ModelVersion;
import io.hops.hopsworks.persistence.entity.project.Project;
import io.hops.hopsworks.persistence.entity.provenance.ModelLink;
import io.hops.hopsworks.persistence.entity.user.Users;
import io.hops.hopsworks.restutils.RESTCodes;
import org.javatuples.Pair;

import javax.ejb.EJB;
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import java.util.logging.Level;
import java.util.logging.Logger;

@Stateless
@TransactionAttribute(TransactionAttributeType.NEVER)
public class ModelUtils {
private static final Logger LOGGER = Logger.getLogger(ModelUtils.class.getName());

@EJB
private AccessController accessCtrl;
Expand All @@ -53,6 +67,16 @@ public class ModelUtils {
private HdfsUsersController hdfsUsersController;
@EJB
private DistributedFsService dfs;
@EJB
private FeaturestoreFacade fsFacade;
@EJB
private FeatureViewFacade fvFacade;
@EJB
private FeatureViewController fvCtrl;
@EJB
private TrainingDatasetFacade tdFacade;
@EJB
private ModelLinkController modelLinkCtrl;

public String getModelsDatasetPath(Project userProject, Project modelRegistryProject) {
String modelsPath = Utils.getProjectPath(userProject.getName()) + Settings.HOPS_MODELS_DATASET + "/";
Expand Down Expand Up @@ -126,9 +150,91 @@ public String getModelFullPath(Project modelRegistryProject, String modelName, I
return Utils.getProjectPath(modelRegistryProject.getName()) +
Settings.HOPS_MODELS_DATASET + "/" + modelName + "/" + modelVersion;
}

public String getModelFullPath(ModelVersion modelVersion) {
return getModelFullPath(modelVersion.getModel().getProject(), modelVersion.getModel().getName(),
modelVersion.getVersion());
}

public String[] getModelNameAndVersion(String mlId) {
int splitIndex = mlId.lastIndexOf("_");
return new String[]{mlId.substring(0, splitIndex), mlId.substring(splitIndex + 1)};
}

public ModelLink createExplicitProvenanceLink(ModelVersion model, ModelDTO modelDTO) {
FeatureView fv = null;
Integer tdVersion = null;
if(modelDTO.getTrainingDataset() != null) {
if(modelDTO.getTrainingDataset().getId() != null) {
TrainingDataset td = tdFacade.find(modelDTO.getTrainingDataset().getId());
if(td == null) {
LOGGER.info("training dataset not found - cannot create model provenance link");
}
return modelLinkCtrl.createParentLink(model, td);
}
Pair<String, Integer> fvNameVersion = splitTdName(modelDTO.getTrainingDataset().getName());
if(fvNameVersion == null) {
LOGGER.info("training dataset name is wrong - cannot create model provenance link");
return null;
}
fv = getFeatureView(modelDTO.getTrainingDataset().getFeaturestoreId(),
fvNameVersion.getValue0(), fvNameVersion.getValue1());
tdVersion = modelDTO.getTrainingDataset().getVersion();
} else if(modelDTO.getFeatureView() != null) {
if(modelDTO.getFeatureView().getId() != null) {
fv = fvFacade.find(modelDTO.getFeatureView().getId());
} else {
fv = getFeatureView(modelDTO.getFeatureView().getFeaturestoreId(),
modelDTO.getFeatureView().getName(), modelDTO.getFeatureView().getVersion());
}
tdVersion = modelDTO.getTrainingDatasetVersion();
}
if(fv == null) {
return null;
}
if(tdVersion == null) {
LOGGER.info("training dataset version is missing - cannot create model provenance link");
return null;
}
try {
TrainingDataset trainingDataset = tdFacade.findByFeatureViewAndVersion(fv, modelDTO.getTrainingDatasetVersion());
return modelLinkCtrl.createParentLink(model, trainingDataset);
} catch (FeaturestoreException e) {
LOGGER.info("training dataset not found - cannot create model provenance link");
return null;
}
}

private Pair<String, Integer> splitTdName(String tdName) {
int split = tdName.lastIndexOf("_");
if(split == -1) {
return null;
}
try {
String fvName = tdName.substring(0, split);
int fvVersion;
try {
fvVersion = Integer.parseInt(tdName.substring(split + 1));
} catch (NumberFormatException e) {
return null;
}
return Pair.with(fvName, fvVersion);
} catch (IndexOutOfBoundsException e) {
return null;
}
}

private FeatureView getFeatureView(Integer fsId, String fvName, Integer fvVersion) {
Featurestore fvFS = fsFacade.findById(fsId);
if(fvFS == null) {
LOGGER.info("feature view featurestore not found - cannot create model provenance link");
return null;
}
try {
return fvCtrl.getByNameVersionAndFeatureStore(fvName, fvVersion, fvFS);
} catch (FeaturestoreException e) {
LOGGER.info("feature view not found - cannot create model provenance link");
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import javax.ejb.Stateless;
import javax.ejb.TransactionAttribute;
import javax.ejb.TransactionAttributeType;
import javax.ws.rs.core.UriBuilder;
import javax.ws.rs.core.UriInfo;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand All @@ -68,29 +69,31 @@ public class ModelsBuilder {
private UsersBuilder usersBuilder;

public ModelDTO uri(ModelDTO dto, UriInfo uriInfo, Project userProject, Project modelRegistryProject) {
dto.setHref(uriInfo.getBaseUriBuilder()
.path(ResourceRequest.Name.PROJECT.toString().toLowerCase())
.path(Integer.toString(userProject.getId()))
.path(ResourceRequest.Name.MODELREGISTRIES.toString().toLowerCase())
.path(Integer.toString(modelRegistryProject.getId()))
.path(ResourceRequest.Name.MODELS.toString().toLowerCase())
.build());
dto.setHref(modelUri(uriInfo, userProject, modelRegistryProject).build());
return dto;
}

public ModelDTO uri(ModelDTO dto, UriInfo uriInfo, Project userProject, Project modelRegistryProject,
ModelVersion modelVersion) {
dto.setHref(uriInfo.getBaseUriBuilder()
.path(ResourceRequest.Name.PROJECT.toString().toLowerCase())
.path(Integer.toString(userProject.getId()))
.path(ResourceRequest.Name.MODELREGISTRIES.toString().toLowerCase())
.path(Integer.toString(modelRegistryProject.getId()))
.path(ResourceRequest.Name.MODELS.toString().toLowerCase())
.path(modelVersion.getModel().getName() + "_" + modelVersion.getVersion())
.build());
dto.setHref(modelVersionUri(uriInfo, userProject, modelRegistryProject, modelVersion).build());
return dto;
}

public UriBuilder modelUri(UriInfo uriInfo, Project userProject, Project modelRegistryProject) {
return uriInfo.getBaseUriBuilder()
.path(ResourceRequest.Name.PROJECT.toString().toLowerCase())
.path(Integer.toString(userProject.getId()))
.path(ResourceRequest.Name.MODELREGISTRIES.toString().toLowerCase())
.path(Integer.toString(modelRegistryProject.getId()))
.path(ResourceRequest.Name.MODELS.toString().toLowerCase());
}

public UriBuilder modelVersionUri(UriInfo uriInfo, Project userProject, Project modelRegistryProject,
ModelVersion modelVersion) {
return modelUri(uriInfo, userProject, modelRegistryProject)
.path(modelVersion.getModel().getName() + "_" + modelVersion.getVersion());
}

public ModelDTO expand(ModelDTO dto, ResourceRequest resourceRequest) {
if (resourceRequest != null && resourceRequest.contains(ResourceRequest.Name.MODELS)) {
dto.setExpand(true);
Expand Down Expand Up @@ -195,4 +198,15 @@ public ModelDTO build(UriInfo uriInfo,
}
return modelDTO;
}

public ModelDTO build(UriInfo uriInfo,
ResourceRequest resourceRequest,
Users user,
Project userProject,
ModelVersion modelVersion)
throws GenericException, ModelRegistryException, MetadataException, FeatureStoreMetadataException,
DatasetException {
return build(uriInfo, resourceRequest, user, userProject, modelVersion.getModel().getProject(), modelVersion,
modelUtils.getModelFullPath(modelVersion));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.hops.hopsworks.api.jwt.JWTHelper;
import io.hops.hopsworks.api.modelregistry.models.dto.ModelDTO;
import io.hops.hopsworks.api.modelregistry.models.tags.ModelTagResource;
import io.hops.hopsworks.api.provenance.ModelProvenanceResource;
import io.hops.hopsworks.api.util.Pagination;
import io.hops.hopsworks.common.api.ResourceRequest;
import io.hops.hopsworks.common.hdfs.DistributedFsService;
Expand Down Expand Up @@ -90,6 +91,8 @@ public class ModelsResource {
private ModelUtils modelUtils;
@Inject
private ModelTagResource tagResource;
@Inject
private ModelProvenanceResource provenanceResource;

private Project userProject;

Expand Down Expand Up @@ -229,6 +232,7 @@ public Response put(@PathParam("id") String id,
accessor = modelUtils.getModelsAccessor(user, userProject, modelProject,
experimentProject);
ModelVersion modelVersion = modelsController.createModelVersion(accessor, modelDTO, jobName, kernelId);
modelUtils.createExplicitProvenanceLink(modelVersion, modelDTO);
ModelDTO dto = modelsBuilder.build(uriInfo, new ResourceRequest(ResourceRequest.Name.MODELS), user, userProject,
modelRegistryProject, modelVersion, modelUtils.getModelFullPath(modelProject, modelVersion.getModel().getName(),
modelVersion.getVersion()));
Expand All @@ -251,4 +255,15 @@ public ModelTagResource tags(@ApiParam(value = "Id of the model", required = tru
this.tagResource.setModel(modelVersion);
return this.tagResource;
}

@Path("/{id}/provenance")
public ModelProvenanceResource provenance(@ApiParam(value = "Id of the model", required = true)
@PathParam("id") String id)
throws ModelRegistryException, ProvenanceException {
this.provenanceResource.setAccessProject(userProject);
this.provenanceResource.setModelRegistry(modelRegistryProject);
ModelVersion modelVersion = modelsController.getModel(modelRegistryProject, id);
this.provenanceResource.setModelVersion(modelVersion);
return this.provenanceResource;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import io.hops.hopsworks.api.dataset.inode.InodeDTO;
import io.hops.hopsworks.api.user.UserDTO;
import io.hops.hopsworks.common.featurestore.featureview.FeatureViewDTO;
import io.hops.hopsworks.common.tags.TagsDTO;
import io.hops.hopsworks.common.api.RestDTO;
import io.hops.hopsworks.common.featurestore.trainingdatasets.TrainingDatasetDTO;
Expand Down Expand Up @@ -78,6 +79,10 @@ public ModelDTO() {
private Integer modelRegistryId;

private TagsDTO tags;

private FeatureViewDTO featureView;

private int trainingDatasetVersion;

private String type = "modelDTO";

Expand Down Expand Up @@ -234,7 +239,23 @@ public UserDTO getCreator() {
public void setCreator(UserDTO creator) {
this.creator = creator;
}


public FeatureViewDTO getFeatureView() {
return featureView;
}

public void setFeatureView(FeatureViewDTO featureView) {
this.featureView = featureView;
}

public int getTrainingDatasetVersion() {
return trainingDatasetVersion;
}

public void setTrainingDatasetVersion(int trainingDatasetVersion) {
this.trainingDatasetVersion = trainingDatasetVersion;
}

@Override
public String toString() {
return "ModelDTO{" +
Expand All @@ -256,6 +277,8 @@ public String toString() {
", trainingDataset=" + trainingDataset +
", modelRegistryId=" + modelRegistryId +
", tags=" + tags +
", featureView=" + featureView +
", trainingDatasetVersion=" + trainingDatasetVersion +
'}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import io.hops.hopsworks.exceptions.FeaturestoreException;
import io.hops.hopsworks.exceptions.GenericException;
import io.hops.hopsworks.exceptions.MetadataException;
import io.hops.hopsworks.exceptions.ModelRegistryException;
import io.hops.hopsworks.exceptions.ProvenanceException;
import io.hops.hopsworks.exceptions.ServiceException;
import io.hops.hopsworks.jwt.annotation.JWTRequired;
Expand Down Expand Up @@ -138,7 +139,7 @@ public Response getLinks(
@Context HttpServletRequest req,
@Context SecurityContext sc)
throws GenericException, FeaturestoreException, DatasetException, ServiceException, MetadataException,
FeatureStoreMetadataException, IOException {
FeatureStoreMetadataException, IOException, ModelRegistryException {
Users user = jwtHelper.getUserPrincipal(sc);
ResourceRequest resourceRequest = new ResourceRequest(ResourceRequest.Name.PROVENANCE);
resourceRequest.setExpansions(explicitProvenanceExpansionBeanParam.getResources());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import io.hops.hopsworks.exceptions.FeaturestoreException;
import io.hops.hopsworks.exceptions.GenericException;
import io.hops.hopsworks.exceptions.MetadataException;
import io.hops.hopsworks.exceptions.ModelRegistryException;
import io.hops.hopsworks.exceptions.ProvenanceException;
import io.hops.hopsworks.exceptions.ServiceException;
import io.hops.hopsworks.jwt.annotation.JWTRequired;
Expand Down Expand Up @@ -132,7 +133,7 @@ public Response getLinks(
@Context HttpServletRequest req,
@Context SecurityContext sc)
throws GenericException, FeaturestoreException, DatasetException, ServiceException, MetadataException,
FeatureStoreMetadataException, IOException {
FeatureStoreMetadataException, IOException, ModelRegistryException {
Users user = jwtHelper.getUserPrincipal(sc);
ResourceRequest resourceRequest = new ResourceRequest(ResourceRequest.Name.PROVENANCE);
resourceRequest.setExpansions(explicitProvenanceExpansionBeanParam.getResources());
Expand Down
Loading

0 comments on commit 70e2683

Please sign in to comment.