Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,22 @@
import io.stargate.sgv2.jsonapi.config.feature.ApiFeature;
import io.stargate.sgv2.jsonapi.config.feature.ApiFeatures;
import io.stargate.sgv2.jsonapi.config.feature.FeaturesConfig;
import io.stargate.sgv2.jsonapi.logging.LoggingMDCContext;
import io.stargate.sgv2.jsonapi.metrics.CommandFeatures;
import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter;
import io.stargate.sgv2.jsonapi.service.cqldriver.CQLSessionCache;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.*;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider;
import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProviderFactory;
import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProviderFactory;
import io.stargate.sgv2.jsonapi.service.schema.DatabaseSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.KeyspaceSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObjectType;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.tables.TableSchemaObject;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
Expand All @@ -29,13 +36,16 @@
* context for a specific request call {@link BuilderSupplier#getBuilder(SchemaObject)} to get a
* {@link BuilderSupplier.Builder} to configure the context for the request.
*
* <p>
* <p><b>NOTE:</b> When {@link BuilderSupplier.Builder#build()} is called it will call {@link
* #addToMDC()} so that the context is added to the logging MDC for the duration of the request. The
* context must be closed via {@link #close()} to remove it from the MDC, this should be done at the
* last possible time in the resource handler so all log messages have the context.
*
* @param <SchemaT> The schema object type that this context is for. There are times we need to lock
* this down to the specific type, if so use the "as" methods such as {@link
* CommandContext#asCollectionContext()}
*/
public class CommandContext<SchemaT extends SchemaObject> {
public class CommandContext<SchemaT extends SchemaObject> implements LoggingMDCContext {

// Common for all instances
private final JsonProcessingMetricsReporter jsonProcessingMetricsReporter;
Expand All @@ -47,11 +57,15 @@ public class CommandContext<SchemaT extends SchemaObject> {

// Request specific
private final SchemaT schemaObject;
private final RequestTracing requestTracing;
private final RequestContext requestContext;
private final EmbeddingProvider
embeddingProvider; // to be removed later, this is a single provider
private final String commandName; // TODO: remove the command name, but it is used in 14 places
private final RequestContext requestContext;
private RequestTracing requestTracing;

// both per request list of objects that want to update the logging MDC context,
// add to this list in the ctor. See {@link #addToMDC()} and {@link #removeFromMDC()}
private final List<LoggingMDCContext> loggingMDCContexts = new ArrayList<>();

// see accessors
private FindAndRerankCommand.HybridLimits hybridLimits;
Expand All @@ -77,19 +91,23 @@ private CommandContext(
RerankingProviderFactory rerankingProviderFactory,
MeterRegistry meterRegistry) {

this.schemaObject = schemaObject;
this.embeddingProvider = embeddingProvider;
this.commandName = commandName;
this.requestContext = requestContext;

this.jsonProcessingMetricsReporter = jsonProcessingMetricsReporter;
// Common for all instances
this.cqlSessionCache = cqlSessionCache;
this.commandConfig = commandConfig;
this.embeddingProviderFactory = embeddingProviderFactory;
this.jsonProcessingMetricsReporter = jsonProcessingMetricsReporter;
this.meterRegistry = meterRegistry;
this.rerankingProviderFactory = rerankingProviderFactory;

// Request specific
this.embeddingProvider = embeddingProvider; // to be removed later, this is a single provider
this.requestContext = requestContext;
this.schemaObject = schemaObject;
this.commandName = commandName; // TODO: remove the command name, but it is used in 14 places
this.apiFeatures = apiFeatures;
this.meterRegistry = meterRegistry;

this.loggingMDCContexts.add(this.requestContext);
this.loggingMDCContexts.add(this.schemaObject.identifier());

var anyTracing =
apiFeatures().isFeatureEnabled(ApiFeature.REQUEST_TRACING)
Expand Down Expand Up @@ -191,41 +209,59 @@ public MeterRegistry meterRegistry() {
}

public boolean isCollectionContext() {
return schemaObject().type() == CollectionSchemaObject.TYPE;
return schemaObject().type() == SchemaObjectType.COLLECTION;
}

@SuppressWarnings("unchecked")
public CommandContext<CollectionSchemaObject> asCollectionContext() {
checkSchemaObjectType(CollectionSchemaObject.TYPE);
checkSchemaObjectType(SchemaObjectType.COLLECTION);
return (CommandContext<CollectionSchemaObject>) this;
}

@SuppressWarnings("unchecked")
public CommandContext<TableSchemaObject> asTableContext() {
checkSchemaObjectType(TableSchemaObject.TYPE);
checkSchemaObjectType(SchemaObjectType.TABLE);
return (CommandContext<TableSchemaObject>) this;
}

@SuppressWarnings("unchecked")
public CommandContext<KeyspaceSchemaObject> asKeyspaceContext() {
checkSchemaObjectType(KeyspaceSchemaObject.TYPE);
checkSchemaObjectType(SchemaObjectType.KEYSPACE);
return (CommandContext<KeyspaceSchemaObject>) this;
}

@SuppressWarnings("unchecked")
public CommandContext<DatabaseSchemaObject> asDatabaseContext() {
checkSchemaObjectType(DatabaseSchemaObject.TYPE);
checkSchemaObjectType(SchemaObjectType.DATABASE);
return (CommandContext<DatabaseSchemaObject>) this;
}

private void checkSchemaObjectType(SchemaObject.SchemaObjectType expectedType) {
private void checkSchemaObjectType(SchemaObjectType expectedType) {
Preconditions.checkArgument(
schemaObject().type() == expectedType,
"SchemaObject type actual was %s expected was %s ",
schemaObject().type(),
expectedType);
}

@Override
public void addToMDC() {
loggingMDCContexts.forEach(LoggingMDCContext::addToMDC);
}

@Override
public void removeFromMDC() {
loggingMDCContexts.forEach(LoggingMDCContext::removeFromMDC);
}

/**
* NOTE: Not using AutoCloseable because it created a lot of linting warnings, we only want to
* close this in the request resource handler.
*/
public void close() throws Exception {
removeFromMDC();
}

/**
* Configure the BuilderSupplier with resources and config that will be used for all the {@link
* CommandContext} that will be created. Then called {@link
Expand Down Expand Up @@ -341,18 +377,21 @@ public CommandContext<SchemaT> build() {
Objects.requireNonNull(commandName, "commandName must not be null");
Objects.requireNonNull(requestContext, "requestContext must not be null");

return new CommandContext<>(
schemaObject,
embeddingProvider,
commandName,
requestContext,
jsonProcessingMetricsReporter,
cqlSessionCache,
commandConfig,
apiFeatures,
embeddingProviderFactory,
rerankingProviderFactory,
meterRegistry);
var context =
new CommandContext<>(
schemaObject,
embeddingProvider,
commandName,
requestContext,
jsonProcessingMetricsReporter,
cqlSessionCache,
commandConfig,
apiFeatures,
embeddingProviderFactory,
rerankingProviderFactory,
meterRegistry);
context.addToMDC();
return context;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
* The schema object a command can be called against.
*
* <p>Example: creteTable runs against the Keyspace , so target is the Keyspace aaron 13 - nove -
* 2024 - not using the {@link
* io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject.SchemaObjectType} because this
* also needs the SYSTEM value, and the schema object design prob needs improvement
* 2024 - not using the {@link io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObjectType}
* because this also needs the SYSTEM value, and the schema object design prob needs improvement
*/
public enum CommandTarget {
COLLECTION,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import io.stargate.sgv2.jsonapi.config.OperationsConfig;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.exception.FilterException;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.tables.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.shredding.collections.DocumentId;
import io.stargate.sgv2.jsonapi.service.shredding.collections.JsonExtensionType;
import io.stargate.sgv2.jsonapi.util.JsonUtil;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.SortDefinition;
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause;
import io.stargate.sgv2.jsonapi.exception.SortException;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.tables.TableSchemaObject;
import io.stargate.sgv2.jsonapi.util.JsonUtil;
import java.util.Map;
import java.util.Objects;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import io.stargate.sgv2.jsonapi.api.model.command.clause.filter.*;
import io.stargate.sgv2.jsonapi.api.model.command.table.definition.datatype.MapComponentDesc;
import io.stargate.sgv2.jsonapi.exception.FilterException;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.operation.filters.table.MapSetListFilterComponent;
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiTypeName;
import io.stargate.sgv2.jsonapi.service.schema.tables.TableSchemaObject;
import io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil;
import java.util.*;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortClause;
import io.stargate.sgv2.jsonapi.api.model.command.clause.sort.SortExpression;
import io.stargate.sgv2.jsonapi.exception.SortException;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiColumnDef;
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiColumnDefContainer;
import io.stargate.sgv2.jsonapi.service.schema.tables.TableSchemaObject;
import io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil;
import io.stargate.sgv2.jsonapi.util.JsonUtil;
import java.util.ArrayList;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.stargate.sgv2.jsonapi.api.v1;

import static io.stargate.sgv2.jsonapi.config.constants.DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD;
import static io.stargate.sgv2.jsonapi.util.CqlIdentifierUtil.cqlIdentifierFromUserInput;

import io.micrometer.core.instrument.MeterRegistry;
import io.smallrye.mutiny.Uni;
Expand Down Expand Up @@ -31,13 +32,14 @@
import io.stargate.sgv2.jsonapi.exception.mappers.ThrowableCommandResultSupplier;
import io.stargate.sgv2.jsonapi.metrics.JsonProcessingMetricsReporter;
import io.stargate.sgv2.jsonapi.service.cqldriver.CqlSessionCacheSupplier;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaCache;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.SchemaObject;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.VectorColumnDefinition;
import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProvider;
import io.stargate.sgv2.jsonapi.service.embedding.operation.EmbeddingProviderFactory;
import io.stargate.sgv2.jsonapi.service.processor.MeteredCommandProcessor;
import io.stargate.sgv2.jsonapi.service.reranking.operation.RerankingProviderFactory;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObjectCacheSupplier;
import io.stargate.sgv2.jsonapi.service.schema.SchemaObjectType;
import io.stargate.sgv2.jsonapi.service.schema.UnscopedSchemaObjectIdentifier;
import jakarta.inject.Inject;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotEmpty;
Expand All @@ -60,13 +62,16 @@
import org.eclipse.microprofile.openapi.annotations.security.SecurityRequirement;
import org.eclipse.microprofile.openapi.annotations.tags.Tag;
import org.jboss.resteasy.reactive.RestResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Path(CollectionResource.BASE_PATH)
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
@SecurityRequirement(name = OpenApiConstants.SecuritySchemes.TOKEN)
@Tag(ref = "Documents")
public class CollectionResource {
private static final Logger LOGGER = LoggerFactory.getLogger(CollectionResource.class);

public static final String BASE_PATH = GeneralResource.BASE_PATH + "/{keyspace}/{collection}";

Expand All @@ -75,20 +80,23 @@ public class CollectionResource {
// TODO remove apiFeatureConfig as a property after cleanup for how we get schema from cache
@Inject private FeaturesConfig apiFeatureConfig;
@Inject private RequestContext requestContext;
@Inject private SchemaCache schemaCache;

private final SchemaObjectCacheSupplier schemaObjectCacheSupplier;
private final CommandContext.BuilderSupplier contextBuilderSupplier;
private final EmbeddingProviderFactory embeddingProviderFactory;
private final MeteredCommandProcessor meteredCommandProcessor;

@Inject
public CollectionResource(
SchemaObjectCacheSupplier schemaObjectCacheSupplier,
MeteredCommandProcessor meteredCommandProcessor,
MeterRegistry meterRegistry,
JsonProcessingMetricsReporter jsonProcessingMetricsReporter,
CqlSessionCacheSupplier sessionCacheSupplier,
EmbeddingProviderFactory embeddingProviderFactory,
RerankingProviderFactory rerankingProviderFactory) {

this.schemaObjectCacheSupplier = schemaObjectCacheSupplier;
this.embeddingProviderFactory = embeddingProviderFactory;
this.meteredCommandProcessor = meteredCommandProcessor;

Expand Down Expand Up @@ -198,12 +206,15 @@ public Uni<RestResponse<CommandResult>> postCommand(
@NotNull @Valid CollectionCommand command,
@PathParam("keyspace") @NotEmpty String keyspace,
@PathParam("collection") @NotEmpty String collection) {
return schemaCache
.getSchemaObject(
requestContext,
keyspace,
collection,
CommandType.DDL.equals(command.commandName().getCommandType()))

var name =
new UnscopedSchemaObjectIdentifier.DefaultKeyspaceScopedName(
cqlIdentifierFromUserInput(keyspace), cqlIdentifierFromUserInput(collection));
var forceRefresh = CommandType.DDL.equals(command.commandName().getCommandType());

return schemaObjectCacheSupplier
.get()
.getTableBased(requestContext, name, requestContext.userAgent(), forceRefresh)
.onItemOrFailure()
.transformToUni(
(schemaObject, throwable) -> {
Expand All @@ -219,19 +230,17 @@ public Uni<RestResponse<CommandResult>> postCommand(
// otherwise use generic for now
return Uni.createFrom().item(new ThrowableCommandResultSupplier(error));
} else {
// TODO No need for the else clause here, simplify

// TODO: This needs to change, currently it is only checking if there is vectorize
// for the $vector column in a collection

VectorColumnDefinition vectorColDef = null;
if (schemaObject.type() == SchemaObject.SchemaObjectType.COLLECTION) {
if (schemaObject.type() == SchemaObjectType.COLLECTION) {
vectorColDef =
schemaObject
.vectorConfig()
.getColumnDefinition(VECTOR_EMBEDDING_TEXT_FIELD)
.orElse(null);
} else if (schemaObject.type() == SchemaObject.SchemaObjectType.TABLE) {
} else if (schemaObject.type() == SchemaObjectType.TABLE) {
vectorColDef =
schemaObject
.vectorConfig()
Expand Down Expand Up @@ -262,7 +271,20 @@ public Uni<RestResponse<CommandResult>> postCommand(
.withRequestContext(requestContext)
.build();

return meteredCommandProcessor.processCommand(commandContext, command);
return meteredCommandProcessor
.processCommand(commandContext, command)
.onTermination()
.invoke(
() -> {
try {
commandContext.close();
} catch (Exception e) {
LOGGER.error(
"Error closing the command context for requestContext={}",
requestContext,
e);
}
});
}
})
.map(commandResult -> commandResult.toRestResponse());
Expand Down
Loading
Loading