// Copyright (C) 2016 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build trace

package sqlite3

/*
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
#include <stdlib.h>

void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void doneTrampoline(sqlite3_context*);
int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
*/
import "C"

import (
	"errors"
	"fmt"
	"reflect"
	"strings"
	"sync"
	"unsafe"
)

// Trace... constants identify the possible events causing callback invocation.
// Values are the same as the corresponding SQLite Trace Event Codes.
const (
	TraceStmt    = C.SQLITE_TRACE_STMT
	TraceProfile = C.SQLITE_TRACE_PROFILE
	TraceRow     = C.SQLITE_TRACE_ROW
	TraceClose   = C.SQLITE_TRACE_CLOSE
)

// TraceInfo describes a single trace event. The callback trampoline
// (traceCallbackTrampoline) fills one in from the arguments SQLite passes
// to the sqlite3_trace_v2() callback and hands it to the user callback.
// Which fields are populated depends on EventCode; see the field comments.
type TraceInfo struct {
	// Pack together the shorter fields, to keep the struct smaller.
	// On a 64-bit machine there would be padding
	// between EventCode and ConnHandle; having AutoCommit here is "free":

	// EventCode is the SQLite trace event code (one of the Trace... constants).
	EventCode uint32
	// AutoCommit is true when sqlite3_get_autocommit() reported a non-zero
	// value for the connection at the time of the event.
	AutoCommit bool
	// ConnHandle is the database connection pointer, converted to uintptr.
	ConnHandle uintptr

	// Usually filled, unless EventCode = TraceClose = SQLITE_TRACE_CLOSE:
	// identifier for a prepared statement:
	StmtHandle uintptr

	// Two strings filled when EventCode = TraceStmt = SQLITE_TRACE_STMT:
	// (1) either the unexpanded SQL text of the prepared statement, or
	//     an SQL comment that indicates the invocation of a trigger;
	// (2) expanded SQL, if requested and if (1) is not an SQL comment.
	StmtOrTrigger string
	ExpandedSQL   string // only if requested (TraceConfig.WantExpandedSQL = true)

	// filled when EventCode = TraceProfile = SQLITE_TRACE_PROFILE:
	// estimated number of nanoseconds that the prepared statement took to run:
	RunTimeNanosec int64

	// DBError is the connection's error state. It is sampled for
	// TraceProfile events and when retrieving the expanded SQL of a
	// TraceStmt event fails; otherwise it keeps its zero value.
	DBError Error
}

// TraceUserCallback gives the signature for a trace function
// provided by the user (Go application programmer).
// SQLite 3.14 documentation (as of September 2, 2016) // for SQL Trace Hook = sqlite3_trace_v2(): // The integer return value from the callback is currently ignored, // though this may change in future releases. Callback implementations // should return zero to ensure future compatibility. type TraceUserCallback func(TraceInfo) int type TraceConfig struct { Callback TraceUserCallback EventMask C.uint WantExpandedSQL bool } func fillDBError(dbErr *Error, db *C.sqlite3) { // See SQLiteConn.lastError(), in file 'sqlite3.go' at the time of writing (Sept 5, 2016) dbErr.Code = ErrNo(C.sqlite3_errcode(db)) dbErr.ExtendedCode = ErrNoExtended(C.sqlite3_extended_errcode(db)) dbErr.err = C.GoString(C.sqlite3_errmsg(db)) } func fillExpandedSQL(info *TraceInfo, db *C.sqlite3, pStmt unsafe.Pointer) { if pStmt == nil { panic("No SQLite statement pointer in P arg of trace_v2 callback") } expSQLiteCStr := C.sqlite3_expanded_sql((*C.sqlite3_stmt)(pStmt)) if expSQLiteCStr == nil { fillDBError(&info.DBError, db) return } info.ExpandedSQL = C.GoString(expSQLiteCStr) } //export traceCallbackTrampoline func traceCallbackTrampoline( traceEventCode C.uint, // Parameter named 'C' in SQLite docs = Context given at registration: ctx unsafe.Pointer, // Parameter named 'P' in SQLite docs (Primary event data?): p unsafe.Pointer, // Parameter named 'X' in SQLite docs (eXtra event data?): xValue unsafe.Pointer) C.int { if ctx == nil { panic(fmt.Sprintf("No context (ev 0x%x)", traceEventCode)) } contextDB := (*C.sqlite3)(ctx) connHandle := uintptr(ctx) var traceConf TraceConfig var found bool if traceEventCode == TraceClose { // clean up traceMap: 'pop' means get and delete traceConf, found = popTraceMapping(connHandle) } else { traceConf, found = lookupTraceMapping(connHandle) } if !found { panic(fmt.Sprintf("Mapping not found for handle 0x%x (ev 0x%x)", connHandle, traceEventCode)) } var info TraceInfo info.EventCode = uint32(traceEventCode) info.AutoCommit = 
(int(C.sqlite3_get_autocommit(contextDB)) != 0) info.ConnHandle = connHandle switch traceEventCode { case TraceStmt: info.StmtHandle = uintptr(p) var xStr string if xValue != nil { xStr = C.GoString((*C.char)(xValue)) } info.StmtOrTrigger = xStr if !strings.HasPrefix(xStr, "--") { // Not SQL comment, therefore the current event // is not related to a trigger. // The user might want to receive the expanded SQL; // let's check: if traceConf.WantExpandedSQL { fillExpandedSQL(&info, contextDB, p) } } case TraceProfile: info.StmtHandle = uintptr(p) if xValue == nil { panic("NULL pointer in X arg of trace_v2 callback for SQLITE_TRACE_PROFILE event") } info.RunTimeNanosec = *(*int64)(xValue) // sample the error //TODO: is it safe? is it useful? fillDBError(&info.DBError, contextDB) case TraceRow: info.StmtHandle = uintptr(p) case TraceClose: handle := uintptr(p) if handle != info.ConnHandle { panic(fmt.Sprintf("Different conn handle 0x%x (expected 0x%x) in SQLITE_TRACE_CLOSE event.", handle, info.ConnHandle)) } default: // Pass unsupported events to the user callback (if configured); // let the user callback decide whether to panic or ignore them. } // Do not execute user callback when the event was not requested by user! // Remember that the Close event is always selected when // registering this callback trampoline with SQLite --- for cleanup. // In the future there may be more events forced to "selected" in SQLite // for the driver's needs. 
if traceConf.EventMask&traceEventCode == 0 { return 0 } r := 0 if traceConf.Callback != nil { r = traceConf.Callback(info) } return C.int(r) } type traceMapEntry struct { config TraceConfig } var traceMapLock sync.Mutex var traceMap = make(map[uintptr]traceMapEntry) func addTraceMapping(connHandle uintptr, traceConf TraceConfig) { traceMapLock.Lock() defer traceMapLock.Unlock() oldEntryCopy, found := traceMap[connHandle] if found { panic(fmt.Sprintf("Adding trace config %v: handle 0x%x already registered (%v).", traceConf, connHandle, oldEntryCopy.config)) } traceMap[connHandle] = traceMapEntry{config: traceConf} fmt.Printf("Added trace config %v: handle 0x%x.\n", traceConf, connHandle) } func lookupTraceMapping(connHandle uintptr) (TraceConfig, bool) { traceMapLock.Lock() defer traceMapLock.Unlock() entryCopy, found := traceMap[connHandle] return entryCopy.config, found } // 'pop' = get and delete from map before returning the value to the caller func popTraceMapping(connHandle uintptr) (TraceConfig, bool) { traceMapLock.Lock() defer traceMapLock.Unlock() entryCopy, found := traceMap[connHandle] if found { delete(traceMap, connHandle) fmt.Printf("Pop handle 0x%x: deleted trace config %v.\n", connHandle, entryCopy.config) } return entryCopy.config, found } // RegisterAggregator makes a Go type available as a SQLite aggregation function. // // Because aggregation is incremental, it's implemented in Go with a // type that has 2 methods: func Step(values) accumulates one row of // data into the accumulator, and func Done() ret finalizes and // returns the aggregate value. "values" and "ret" may be any type // supported by RegisterFunc. // // RegisterAggregator takes as implementation a constructor function // that constructs an instance of the aggregator type each time an // aggregation begins. The constructor must return a pointer to a // type, or an interface that implements Step() and Done(). 
// // The constructor function and the Step/Done methods may optionally // return an error in addition to their other return values. // // See _example/go_custom_funcs for a detailed example. func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error { var ai aggInfo ai.constructor = reflect.ValueOf(impl) t := ai.constructor.Type() if t.Kind() != reflect.Func { return errors.New("non-function passed to RegisterAggregator") } if t.NumOut() != 1 && t.NumOut() != 2 { return errors.New("SQLite aggregator constructors must return 1 or 2 values") } if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { return errors.New("Second return value of SQLite function must be error") } if t.NumIn() != 0 { return errors.New("SQLite aggregator constructors must not have arguments") } agg := t.Out(0) switch agg.Kind() { case reflect.Ptr, reflect.Interface: default: return errors.New("SQlite aggregator constructor must return a pointer object") } stepFn, found := agg.MethodByName("Step") if !found { return errors.New("SQlite aggregator doesn't have a Step() function") } step := stepFn.Type if step.NumOut() != 0 && step.NumOut() != 1 { return errors.New("SQlite aggregator Step() function must return 0 or 1 values") } if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { return errors.New("type of SQlite aggregator Step() return value must be error") } stepNArgs := step.NumIn() start := 0 if agg.Kind() == reflect.Ptr { // Skip over the method receiver stepNArgs-- start++ } if step.IsVariadic() { stepNArgs-- } for i := start; i < start+stepNArgs; i++ { conv, err := callbackArg(step.In(i)) if err != nil { return err } ai.stepArgConverters = append(ai.stepArgConverters, conv) } if step.IsVariadic() { conv, err := callbackArg(t.In(start + stepNArgs).Elem()) if err != nil { return err } ai.stepVariadicConverter = conv // Pass -1 to sqlite so that it allows any number of // arguments. 
The call helper verifies that the minimum number // of arguments is present for variadic functions. stepNArgs = -1 } doneFn, found := agg.MethodByName("Done") if !found { return errors.New("SQlite aggregator doesn't have a Done() function") } done := doneFn.Type doneNArgs := done.NumIn() if agg.Kind() == reflect.Ptr { // Skip over the method receiver doneNArgs-- } if doneNArgs != 0 { return errors.New("SQlite aggregator Done() function must have no arguments") } if done.NumOut() != 1 && done.NumOut() != 2 { return errors.New("SQLite aggregator Done() function must return 1 or 2 values") } if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { return errors.New("second return value of SQLite aggregator Done() function must be error") } conv, err := callbackRet(done.Out(0)) if err != nil { return err } ai.doneRetConverter = conv ai.active = make(map[int64]reflect.Value) ai.next = 1 // ai must outlast the database connection, or we'll have dangling pointers. c.aggregators = append(c.aggregators, &ai) cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) opts := C.SQLITE_UTF8 if pure { opts |= C.SQLITE_DETERMINISTIC } rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline) if rv != C.SQLITE_OK { return c.lastError() } return nil } // SetTrace installs or removes the trace callback for the given database connection. // It's not named 'RegisterTrace' because only one callback can be kept and called. // Calling SetTrace a second time on same database connection // overrides (cancels) any prior callback and all its settings: // event mask, etc. func (c *SQLiteConn) SetTrace(requested *TraceConfig) error { connHandle := uintptr(unsafe.Pointer(c.db)) _, _ = popTraceMapping(connHandle) if requested == nil { // The traceMap entry was deleted already by popTraceMapping(): // can disable all events now, no need to watch for TraceClose. 
err := c.setSQLiteTrace(0) return err } reqCopy := *requested // Disable potentially expensive operations // if their result will not be used. We are doing this // just in case the caller provided nonsensical input. if reqCopy.EventMask&TraceStmt == 0 { reqCopy.WantExpandedSQL = false } addTraceMapping(connHandle, reqCopy) // The callback trampoline function does cleanup on Close event, // regardless of the presence or absence of the user callback. // Therefore it needs the Close event to be selected: actualEventMask := uint(reqCopy.EventMask | TraceClose) err := c.setSQLiteTrace(actualEventMask) return err } func (c *SQLiteConn) setSQLiteTrace(sqliteEventMask uint) error { rv := C.sqlite3_trace_v2(c.db, C.uint(sqliteEventMask), (*[0]byte)(unsafe.Pointer(C.traceCallbackTrampoline)), unsafe.Pointer(c.db)) // Fourth arg is same as first: we are // passing the database connection handle as callback context. if rv != C.SQLITE_OK { return c.lastError() } return nil }