Skip to content

Commit

Permalink
[FSTORE-1314] Support defining similarity function in embedding index (#…
Browse files Browse the repository at this point in the history
…1783)

* use sim function

* add license

* update license

* refactor getOpensearchFunction

(cherry picked from commit 5bf8980)
  • Loading branch information
kennethmhc committed May 17, 2024
1 parent 96cfe19 commit 83ec6f6
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,12 @@ protected String createMapping(String prefix, Collection<EmbeddingFeature> embed
" }";
String embeddingFieldString = " \"%s\": {\n" +
" \"type\": \"knn_vector\",\n" +
" \"dimension\": %d\n" +
" \"dimension\": %d,\n" +
" \"method\": {\n" +
" \"name\": \"hnsw\",\n" +
" \"space_type\": \"%s\",\n" +
" \"engine\": \"nmslib\"\n" +
" }\n" +
" }";
String fieldString = " \"%s\": {\n" +
" \"type\": \"%s\"\n" +
Expand All @@ -307,7 +312,9 @@ protected String createMapping(String prefix, Collection<EmbeddingFeature> embed

for (EmbeddingFeature feature : embeddingFeatures) {
fieldMapping.add(String.format(
embeddingFieldString, prefix + feature.getName(), feature.getDimension()));
embeddingFieldString,
prefix + feature.getName(), feature.getDimension(),
feature.getSimilarityFunctionType().getOpensearchFunction()));
}
for (FeatureGroupFeatureDTO feature : features) {
if (!embeddingFeatureNames.contains(feature.getName())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.hops.hopsworks.common.featurestore.featuregroup;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionType;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature;
import lombok.AllArgsConstructor;
import lombok.Getter;
Expand All @@ -30,13 +31,13 @@ public class EmbeddingFeatureDTO {
@Getter
private String name;
@Getter
private String similarityFunctionType;
private SimilarityFunctionType similarityFunctionType;
@Getter
private Integer dimension;
@Getter
private ModelDto model;

public EmbeddingFeatureDTO(String name, String similarityFunctionType, Integer dimension) {
public EmbeddingFeatureDTO(String name, SimilarityFunctionType similarityFunctionType, Integer dimension) {
this.name = name;
this.similarityFunctionType = similarityFunctionType;
this.dimension = dimension;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
import java.util.Set;
import java.util.stream.Collectors;

import static io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionType.COSINE;
import static io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionType.DOT_PRODUCT;
import static io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionType.L2_NORM;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand Down Expand Up @@ -71,8 +74,9 @@ public void setup() throws Exception {
@Test
public void testCreateIndex() {
List<EmbeddingFeature> embeddingFeatures = new ArrayList<>();
embeddingFeatures.add(new EmbeddingFeature(null, "vector", 512, "l2"));
embeddingFeatures.add(new EmbeddingFeature(null, "vector2", 128, "l2"));
embeddingFeatures.add(new EmbeddingFeature(null, "vector", 512, L2_NORM));
embeddingFeatures.add(new EmbeddingFeature(null, "vector2", 128, COSINE));
embeddingFeatures.add(new EmbeddingFeature(null, "vector3", 64, DOT_PRODUCT));
List<FeatureGroupFeatureDTO> features = new ArrayList<>();
Set<String> offlineTypes =
FeaturestoreConstants.SUGGESTED_HIVE_FEATURE_TYPES.stream().map(type -> type.split(" ")[0])
Expand All @@ -83,6 +87,7 @@ public void testCreateIndex() {
}
features.add(new FeatureGroupFeatureDTO("vector", "ARRAY<DOUBLE>"));
features.add(new FeatureGroupFeatureDTO("vector2", "ARRAY<DOUBLE>"));
features.add(new FeatureGroupFeatureDTO("vector3", "ARRAY<DOUBLE>"));


String expectedMapping = "{\n" +
Expand All @@ -96,11 +101,30 @@ public void testCreateIndex() {
" \"properties\": {\n" +
" \"vector\": {\n" +
" \"type\": \"knn_vector\",\n" +
" \"dimension\": 512\n" +
" \"dimension\": 512,\n" +
" \"method\": {\n" +
" \"name\": \"hnsw\",\n" +
" \"space_type\": \"l2\",\n" +
" \"engine\": \"nmslib\"\n" +
" }\n" +
" },\n" +
" \"vector2\": {\n" +
" \"type\": \"knn_vector\",\n" +
" \"dimension\": 128\n" +
" \"dimension\": 128,\n" +
" \"method\": {\n" +
" \"name\": \"hnsw\",\n" +
" \"space_type\": \"cosinesimil\",\n" +
" \"engine\": \"nmslib\"\n" +
" }\n" +
" },\n" +
" \"vector3\": {\n" +
" \"type\": \"knn_vector\",\n" +
" \"dimension\": 64,\n" +
" \"method\": {\n" +
" \"name\": \"hnsw\",\n" +
" \"space_type\": \"innerproduct\",\n" +
" \"engine\": \"nmslib\"\n" +
" }\n" +
" }";

for (String offlineType : offlineTypes) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.util.Arrays;
import java.util.List;

import static io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionType.L2_NORM;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
Expand Down Expand Up @@ -706,9 +707,9 @@ public void testVerifyEmbeddingFeatureExist_pass() throws FeaturestoreException
FeatureGroupFeatureDTO feature2 = new FeatureGroupFeatureDTO("feature2");
FeatureGroupFeatureDTO feature3 = new FeatureGroupFeatureDTO("feature3");

EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature3 = new EmbeddingFeatureDTO("feature3", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", L2_NORM, 3);
EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", L2_NORM, 3);
EmbeddingFeatureDTO embeddingFeature3 = new EmbeddingFeatureDTO("feature3", L2_NORM, 3);

List<FeatureGroupFeatureDTO> features = Arrays.asList(feature1, feature2, feature3);
List<EmbeddingFeatureDTO> embeddingFeatures =
Expand All @@ -732,10 +733,10 @@ public void testVerifyEmbeddingFeatureExist_fail() throws FeaturestoreException
FeatureGroupFeatureDTO feature2 = new FeatureGroupFeatureDTO("feature2");
FeatureGroupFeatureDTO feature3 = new FeatureGroupFeatureDTO("feature3");

EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", L2_NORM, 3);
EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", L2_NORM, 3);
EmbeddingFeatureDTO embeddingFeature3 =
new EmbeddingFeatureDTO("feature4", "l2_norm", 3); // this does not exist in feature group
new EmbeddingFeatureDTO("feature4", L2_NORM, 3); // this does not exist in feature group

List<FeatureGroupFeatureDTO> features = Arrays.asList(feature1, feature2, feature3);
List<EmbeddingFeatureDTO> embeddingFeatures =
Expand All @@ -753,9 +754,9 @@ public void testVerifyEmbeddingFeatureExist_fail() throws FeaturestoreException

@Test
public void testVerifyEmbeddingIndex_pass() throws FeaturestoreException {
EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature3 = new EmbeddingFeatureDTO("feature3", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", L2_NORM, 3);
EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", L2_NORM, 3);
EmbeddingFeatureDTO embeddingFeature3 = new EmbeddingFeatureDTO("feature3", L2_NORM, 3);

List<EmbeddingFeatureDTO> embeddingFeatures =
Arrays.asList(embeddingFeature1, embeddingFeature2, embeddingFeature3);
Expand All @@ -775,9 +776,9 @@ public void testVerifyEmbeddingIndex_pass() throws FeaturestoreException {

@Test
public void testVerifyEmbeddingIndex_fail() throws FeaturestoreException {
EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature3 = new EmbeddingFeatureDTO("feature3", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", L2_NORM, 3);
EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", L2_NORM, 3);
EmbeddingFeatureDTO embeddingFeature3 = new EmbeddingFeatureDTO("feature3", L2_NORM, 3);

List<EmbeddingFeatureDTO> embeddingFeatures =
Arrays.asList(embeddingFeature1, embeddingFeature2, embeddingFeature3);
Expand All @@ -796,7 +797,7 @@ public void testVerifyEmbeddingIndex_fail() throws FeaturestoreException {
}

private FeaturegroupDTO createFeaturegroupDtoWithIndexName(String indexName) {
EmbeddingFeatureDTO embeddingFeature = new EmbeddingFeatureDTO("feature3", "l2_norm", 3);
EmbeddingFeatureDTO embeddingFeature = new EmbeddingFeatureDTO("feature3", L2_NORM, 3);

List<EmbeddingFeatureDTO> embeddingFeatures =
Arrays.asList(embeddingFeature);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import javax.persistence.Basic;
import javax.persistence.Column;
import javax.persistence.Entity;
import javax.persistence.EnumType;
import javax.persistence.Enumerated;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
Expand All @@ -48,7 +50,8 @@ public class EmbeddingFeature implements Serializable {
@Column
private Integer dimension;
@Column(name = "similarity_function_type")
private String similarityFunctionType;
@Enumerated(EnumType.STRING)
private SimilarityFunctionType similarityFunctionType;
@JoinColumn(name = "model_version_id", referencedColumnName = "id")
@OneToOne
private ModelVersion modelVersion;
Expand All @@ -57,15 +60,15 @@ public EmbeddingFeature() {
}

public EmbeddingFeature(Embedding embedding, String name, Integer dimension,
String similarityFunctionType) {
SimilarityFunctionType similarityFunctionType) {
this.embedding = embedding;
this.name = name;
this.dimension = dimension;
this.similarityFunctionType = similarityFunctionType;
}

public EmbeddingFeature(Embedding embedding, String name, Integer dimension,
String similarityFunctionType, ModelVersion modelVersion) {
SimilarityFunctionType similarityFunctionType, ModelVersion modelVersion) {
this.embedding = embedding;
this.name = name;
this.dimension = dimension;
Expand All @@ -74,7 +77,7 @@ public EmbeddingFeature(Embedding embedding, String name, Integer dimension,
}

public EmbeddingFeature(Integer id, Embedding embedding, String name, Integer dimension,
String similarityFunctionType) {
SimilarityFunctionType similarityFunctionType) {
this.id = id;
this.embedding = embedding;
this.name = name;
Expand All @@ -98,7 +101,7 @@ public Integer getDimension() {
return dimension;
}

public String getSimilarityFunctionType() {
public SimilarityFunctionType getSimilarityFunctionType() {
return similarityFunctionType;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* This file is part of Hopsworks
* Copyright (C) 2024, Hopsworks AB. All rights reserved
*
* Hopsworks is free software: you can redistribute it and/or modify it under the terms of
* the GNU Affero General Public License as published by the Free Software Foundation,
* either version 3 of the License, or (at your option) any later version.
*
* Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
* without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
* PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see <https://www.gnu.org/licenses/>.
*/

package io.hops.hopsworks.persistence.entity.featurestore.featuregroup;

import com.fasterxml.jackson.annotation.JsonProperty;

/**
 * Similarity functions that can be selected for an embedding feature.
 *
 * <p>Each constant is exposed in JSON under the lowercase name given by its
 * {@link JsonProperty} annotation and carries the name of the matching
 * OpenSearch function.
 */
public enum SimilarityFunctionType {

  /** JSON name {@code l2_norm}; OpenSearch function {@code l2}. */
  @JsonProperty("l2_norm")
  L2_NORM("l2"),

  /** JSON name {@code cosine}; OpenSearch function {@code cosinesimil}. */
  @JsonProperty("cosine")
  COSINE("cosinesimil"),

  /** JSON name {@code dot_product}; OpenSearch function {@code innerproduct}. */
  @JsonProperty("dot_product")
  DOT_PRODUCT("innerproduct");

  /** OpenSearch function name associated with this constant. */
  private final String opensearchFunction;

  /**
   * Binds a constant to its OpenSearch function name.
   *
   * @param functionName the OpenSearch function name for this constant
   */
  SimilarityFunctionType(String functionName) {
    opensearchFunction = functionName;
  }

  /**
   * Returns the OpenSearch function name associated with this constant.
   *
   * @return the OpenSearch function name, e.g. {@code "l2"} for {@link #L2_NORM}
   */
  public String getOpensearchFunction() {
    return this.opensearchFunction;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* This file is part of Hopsworks
* Copyright (C) 2024, Hopsworks AB. All rights reserved
*
* Hopsworks is free software: you can redistribute it and/or modify it under the terms of
* the GNU Affero General Public License as published by the Free Software Foundation,
* either version 3 of the License, or (at your option) any later version.
*
* Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
* without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
* PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see <https://www.gnu.org/licenses/>.
*/

package io.hops.hopsworks.persistence.entity.featurestore.featuregroup;

import javax.persistence.AttributeConverter;
import javax.persistence.Converter;
import java.util.Locale;

/**
 * JPA attribute converter between {@link SimilarityFunctionType} and its
 * database representation, which is the enum constant name in lowercase
 * (for example {@code L2_NORM} is stored as {@code "l2_norm"}).
 *
 * <p>Case conversion always uses {@link Locale#ROOT} so the stored value does
 * not depend on the JVM default locale. With a Turkish default locale the
 * locale-sensitive overloads would map {@code "COSINE"} to {@code "cosıne"}
 * and {@code "cosine"} to {@code "COSİNE"}, corrupting stored values and
 * breaking the lookup on read.
 */
@Converter(autoApply = true)
public class SimilarityFunctionTypeConverter implements AttributeConverter<SimilarityFunctionType, String> {

  /**
   * Converts an enum value to the string stored in the database column.
   *
   * @param attribute the enum value, may be {@code null}
   * @return the lowercase constant name, or {@code null} if {@code attribute} is {@code null}
   */
  @Override
  public String convertToDatabaseColumn(SimilarityFunctionType attribute) {
    if (attribute == null) {
      // An unset attribute maps to SQL NULL instead of failing with a NullPointerException.
      return null;
    }
    return attribute.name().toLowerCase(Locale.ROOT);
  }

  /**
   * Converts a stored column value back to the enum value.
   *
   * @param dbData the stored string, may be {@code null}
   * @return the matching enum constant, or {@code null} if {@code dbData} is {@code null}
   * @throws IllegalArgumentException if {@code dbData} does not match any constant name
   */
  @Override
  public SimilarityFunctionType convertToEntityAttribute(String dbData) {
    if (dbData == null) {
      // A NULL column maps to an unset attribute instead of failing with a NullPointerException.
      return null;
    }
    return SimilarityFunctionType.valueOf(dbData.toUpperCase(Locale.ROOT));
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
<class>io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup</class>
<class>io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding</class>
<class>io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature</class>
<class>io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionTypeConverter</class>
<class>io.hops.hopsworks.persistence.entity.featurestore.featuregroup.ondemand.OnDemandFeaturegroup</class>
<class>io.hops.hopsworks.persistence.entity.featurestore.featuregroup.cached.CachedFeaturegroup</class>
<class>io.hops.hopsworks.persistence.entity.featurestore.featuregroup.cached.FeatureGroupCommit</class>
Expand Down

0 comments on commit 83ec6f6

Please sign in to comment.