|
| 1 | +package datadog.trace.llmobs.writer.ddintake |
| 2 | + |
| 3 | +import com.fasterxml.jackson.databind.ObjectMapper |
| 4 | +import datadog.communication.serialization.ByteBufferConsumer |
| 5 | +import datadog.communication.serialization.FlushingBuffer |
| 6 | +import datadog.communication.serialization.msgpack.MsgPackWriter |
| 7 | +import datadog.trace.api.llmobs.LLMObs |
| 8 | +import datadog.trace.bootstrap.instrumentation.api.InternalSpanTypes |
| 9 | +import datadog.trace.bootstrap.instrumentation.api.Tags |
| 10 | +import datadog.trace.common.writer.ListWriter |
| 11 | +import datadog.trace.core.test.DDCoreSpecification |
| 12 | +import org.msgpack.jackson.dataformat.MessagePackFactory |
| 13 | +import spock.lang.Shared |
| 14 | + |
| 15 | +import java.nio.ByteBuffer |
| 16 | +import java.nio.channels.WritableByteChannel |
| 17 | + |
| 18 | +class LLMObsSpanMapperTest extends DDCoreSpecification { |
| 19 | + |
| 20 | + @Shared |
| 21 | + ObjectMapper objectMapper = new ObjectMapper(new MessagePackFactory()) |
| 22 | + |
| 23 | + def "test LLMObsSpanMapper serialization"() { |
| 24 | + setup: |
| 25 | + def mapper = new LLMObsSpanMapper() |
| 26 | + def tracer = tracerBuilder().writer(new ListWriter()).build() |
| 27 | + |
| 28 | + |
| 29 | + // Create a real LLMObs span using the tracer |
| 30 | + def llmSpan = tracer.buildSpan("chat-completion") |
| 31 | + .withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) |
| 32 | + .withTag("_ml_obs_tag.model_name", "gpt-4") |
| 33 | + .withTag("_ml_obs_tag.model_provider", "openai") |
| 34 | + .withTag("_ml_obs_metric.input_tokens", 50) |
| 35 | + .withTag("_ml_obs_metric.output_tokens", 25) |
| 36 | + .withTag("_ml_obs_metric.total_tokens", 75) |
| 37 | + .start() |
| 38 | + |
| 39 | + llmSpan.setSpanType(InternalSpanTypes.LLMOBS) |
| 40 | + |
| 41 | + def inputMessages = [LLMObs.LLMMessage.from("user", "Hello, what's the weather like?")] |
| 42 | + def outputMessages = [LLMObs.LLMMessage.from("assistant", "I'll help you check the weather.")] |
| 43 | + llmSpan.setTag("_ml_obs_tag.input", inputMessages) |
| 44 | + llmSpan.setTag("_ml_obs_tag.output", outputMessages) |
| 45 | + llmSpan.setTag("_ml_obs_tag.metadata", [temperature: 0.7, max_tokens: 100]) |
| 46 | + |
| 47 | + llmSpan.finish() |
| 48 | + |
| 49 | + def trace = [llmSpan] |
| 50 | + CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer() |
| 51 | + MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(1024, sink)) |
| 52 | + |
| 53 | + when: |
| 54 | + packer.format(trace, mapper) |
| 55 | + packer.flush() |
| 56 | + |
| 57 | + then: |
| 58 | + sink.captured != null |
| 59 | + def payload = mapper.newPayload() |
| 60 | + payload.withBody(1, sink.captured) |
| 61 | + def channel = new ByteArrayOutputStream() |
| 62 | + payload.writeTo(new WritableByteChannel() { |
| 63 | + @Override |
| 64 | + int write(ByteBuffer src) throws IOException { |
| 65 | + def bytes = new byte[src.remaining()] |
| 66 | + src.get(bytes) |
| 67 | + channel.write(bytes) |
| 68 | + return bytes.length |
| 69 | + } |
| 70 | + |
| 71 | + @Override |
| 72 | + boolean isOpen() { return true } |
| 73 | + |
| 74 | + @Override |
| 75 | + void close() throws IOException { } |
| 76 | + }) |
| 77 | + def result = objectMapper.readValue(channel.toByteArray(), Map) |
| 78 | + |
| 79 | + then: |
| 80 | + result.containsKey("event_type") |
| 81 | + result["event_type"] == "span" |
| 82 | + result.containsKey("_dd.stage") |
| 83 | + result["_dd.stage"] == "raw" |
| 84 | + result.containsKey("spans") |
| 85 | + result["spans"] instanceof List |
| 86 | + result["spans"].size() == 1 |
| 87 | + |
| 88 | + def spanData = result["spans"][0] |
| 89 | + spanData["name"] == "chat-completion" |
| 90 | + spanData.containsKey("span_id") |
| 91 | + spanData.containsKey("trace_id") |
| 92 | + spanData.containsKey("start_ns") |
| 93 | + spanData.containsKey("duration") |
| 94 | + spanData["error"] == 0 |
| 95 | + |
| 96 | + spanData.containsKey("meta") |
| 97 | + spanData["meta"]["span.kind"] == "llm" |
| 98 | + spanData["meta"].containsKey("input.messages") |
| 99 | + spanData["meta"].containsKey("output.messages") |
| 100 | + spanData["meta"].containsKey("metadata") |
| 101 | + |
| 102 | + spanData.containsKey("metrics") |
| 103 | + spanData["metrics"]["input_tokens"] == 50.0 |
| 104 | + spanData["metrics"]["output_tokens"] == 25.0 |
| 105 | + spanData["metrics"]["total_tokens"] == 75.0 |
| 106 | + |
| 107 | + spanData.containsKey("tags") |
| 108 | + spanData["tags"].contains("language:jvm") |
| 109 | + } |
| 110 | + |
| 111 | + static class CapturingByteBufferConsumer implements ByteBufferConsumer { |
| 112 | + |
| 113 | + ByteBuffer captured |
| 114 | + |
| 115 | + @Override |
| 116 | + void accept(int messageCount, ByteBuffer buffer) { |
| 117 | + captured = buffer |
| 118 | + } |
| 119 | + } |
| 120 | + |
| 121 | +} |
0 commit comments