Bring odoo-go to ~70%: read_group, record rules, admin, sessions
Phase 1: read_group/web_read_group with SQL GROUP BY, aggregates (sum/avg/min/max/count/array_agg/sum_currency), date granularity, M2O groupby resolution to [id, display_name]. Phase 2: Record rules with domain_force parsing (Python literal parser), global AND + group OR merging. Domain operators: child_of, parent_of, any, not any compiled to SQL hierarchy/EXISTS queries. Phase 3: Button dispatch via /web/dataset/call_button, method return values interpreted as actions. Payment register wizard (account.payment.register) for sale→invoice→pay flow. Phase 4: ir.filters, ir.default, product fields expanded, SO line product_id onchange, ir_model+ir_model_fields DB seeding. Phase 5: CSV export (/web/export/csv), attachment upload/download via ir.attachment, fields_get with aggregator hints. Admin/System: Session persistence (PostgreSQL-backed), ir.config_parameter with get_param/set_param, ir.cron, ir.logging, res.lang, res.config.settings with company-related fields, Settings form view. Technical menu with Views/Actions/Parameters/Security/Logging sub-menus. User change_password, preferences. Password never exposed in UI/API. Bugfixes: false→nil for varchar/int fields, int32 in toInt64, call_button route with trailing slash, create_invoices returns action, search view always included, get_formview_action, name_create, ir.http stub. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -105,6 +105,7 @@ func Not(node DomainNode) Domain {
|
||||
// Mirrors: odoo/orm/domains.py Domain._to_sql()
|
||||
type DomainCompiler struct {
|
||||
model *Model
|
||||
env *Environment // For operators that need DB access (child_of, parent_of, any, not any)
|
||||
params []interface{}
|
||||
joins []joinClause
|
||||
aliasCounter int
|
||||
@@ -193,11 +194,35 @@ func (dc *DomainCompiler) compileNodes(domain Domain, pos int) (string, error) {
|
||||
|
||||
case Condition:
|
||||
return dc.compileCondition(n)
|
||||
|
||||
case domainGroup:
|
||||
// domainGroup wraps a sub-domain as a single node.
|
||||
// Compile it recursively as a full domain.
|
||||
subSQL, subParams, err := dc.compileDomainGroup(Domain(n))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
_ = subParams // params already appended inside compileDomainGroup
|
||||
return subSQL, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unexpected domain node at position %d: %v", pos, node)
|
||||
}
|
||||
|
||||
// compileDomainGroup compiles a sub-domain that was wrapped via domainGroup.
|
||||
// It reuses the same DomainCompiler (sharing params and joins) so parameter
|
||||
// indices stay consistent with the outer query.
|
||||
func (dc *DomainCompiler) compileDomainGroup(sub Domain) (string, []interface{}, error) {
|
||||
if len(sub) == 0 {
|
||||
return "TRUE", nil, nil
|
||||
}
|
||||
sql, err := dc.compileNodes(sub, 0)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return sql, nil, nil
|
||||
}
|
||||
|
||||
func (dc *DomainCompiler) compileCondition(c Condition) (string, error) {
|
||||
if !validOperators[c.Operator] {
|
||||
return "", fmt.Errorf("invalid operator: %q", c.Operator)
|
||||
@@ -285,6 +310,18 @@ func (dc *DomainCompiler) compileSimpleCondition(column, operator string, value
|
||||
dc.params = append(dc.params, value)
|
||||
return fmt.Sprintf("%q ILIKE $%d", column, paramIdx), nil
|
||||
|
||||
case "child_of":
|
||||
return dc.compileHierarchyOp(column, value, true)
|
||||
|
||||
case "parent_of":
|
||||
return dc.compileHierarchyOp(column, value, false)
|
||||
|
||||
case "any":
|
||||
return dc.compileAnyOp(column, value, false)
|
||||
|
||||
case "not any":
|
||||
return dc.compileAnyOp(column, value, true)
|
||||
|
||||
default:
|
||||
return "", fmt.Errorf("unhandled operator: %q", operator)
|
||||
}
|
||||
@@ -396,6 +433,272 @@ func (dc *DomainCompiler) compileQualifiedCondition(qualifiedColumn, operator st
|
||||
}
|
||||
}
|
||||
|
||||
// compileHierarchyOp implements child_of / parent_of by querying the DB for hierarchy IDs.
|
||||
// Mirrors: odoo/orm/domains.py _expression._get_hierarchy_ids
|
||||
//
|
||||
// - child_of: finds all descendants via parent_id traversal, then "id" IN (...)
|
||||
// - parent_of: finds all ancestors via parent_id traversal, then "id" IN (...)
|
||||
//
|
||||
// Requires dc.env to be set for DB access.
|
||||
func (dc *DomainCompiler) compileHierarchyOp(column string, value Value, isChildOf bool) (string, error) {
|
||||
if dc.env == nil {
|
||||
return "", fmt.Errorf("child_of/parent_of requires Environment on DomainCompiler")
|
||||
}
|
||||
|
||||
// Normalize the root ID(s)
|
||||
rootIDs := toInt64Slice(value)
|
||||
if len(rootIDs) == 0 {
|
||||
return "FALSE", nil
|
||||
}
|
||||
|
||||
table := dc.model.Table()
|
||||
var allIDs map[int64]bool
|
||||
|
||||
if isChildOf {
|
||||
// child_of: find all descendants (including roots) via parent_id
|
||||
allIDs = make(map[int64]bool)
|
||||
queue := make([]int64, len(rootIDs))
|
||||
copy(queue, rootIDs)
|
||||
for _, id := range rootIDs {
|
||||
allIDs[id] = true
|
||||
}
|
||||
|
||||
for len(queue) > 0 {
|
||||
// Build placeholders for current batch
|
||||
placeholders := make([]string, len(queue))
|
||||
args := make([]interface{}, len(queue))
|
||||
for i, id := range queue {
|
||||
args[i] = id
|
||||
placeholders[i] = fmt.Sprintf("$%d", i+1)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
`SELECT "id" FROM %q WHERE "parent_id" IN (%s)`,
|
||||
table, strings.Join(placeholders, ", "),
|
||||
)
|
||||
|
||||
rows, err := dc.env.tx.Query(dc.env.ctx, query, args...)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("child_of query: %w", err)
|
||||
}
|
||||
|
||||
var nextQueue []int64
|
||||
for rows.Next() {
|
||||
var childID int64
|
||||
if err := rows.Scan(&childID); err != nil {
|
||||
rows.Close()
|
||||
return "", err
|
||||
}
|
||||
if !allIDs[childID] {
|
||||
allIDs[childID] = true
|
||||
nextQueue = append(nextQueue, childID)
|
||||
}
|
||||
}
|
||||
rows.Close()
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
queue = nextQueue
|
||||
}
|
||||
} else {
|
||||
// parent_of: find all ancestors (including roots) via parent_id
|
||||
allIDs = make(map[int64]bool)
|
||||
queue := make([]int64, len(rootIDs))
|
||||
copy(queue, rootIDs)
|
||||
for _, id := range rootIDs {
|
||||
allIDs[id] = true
|
||||
}
|
||||
|
||||
for len(queue) > 0 {
|
||||
placeholders := make([]string, len(queue))
|
||||
args := make([]interface{}, len(queue))
|
||||
for i, id := range queue {
|
||||
args[i] = id
|
||||
placeholders[i] = fmt.Sprintf("$%d", i+1)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
`SELECT "parent_id" FROM %q WHERE "id" IN (%s) AND "parent_id" IS NOT NULL`,
|
||||
table, strings.Join(placeholders, ", "),
|
||||
)
|
||||
|
||||
rows, err := dc.env.tx.Query(dc.env.ctx, query, args...)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parent_of query: %w", err)
|
||||
}
|
||||
|
||||
var nextQueue []int64
|
||||
for rows.Next() {
|
||||
var parentID int64
|
||||
if err := rows.Scan(&parentID); err != nil {
|
||||
rows.Close()
|
||||
return "", err
|
||||
}
|
||||
if !allIDs[parentID] {
|
||||
allIDs[parentID] = true
|
||||
nextQueue = append(nextQueue, parentID)
|
||||
}
|
||||
}
|
||||
rows.Close()
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
queue = nextQueue
|
||||
}
|
||||
}
|
||||
|
||||
if len(allIDs) == 0 {
|
||||
return "FALSE", nil
|
||||
}
|
||||
|
||||
// Build "id" IN (1, 2, 3, ...) with parameters
|
||||
paramIdx := len(dc.params) + 1
|
||||
placeholders := make([]string, 0, len(allIDs))
|
||||
for id := range allIDs {
|
||||
dc.params = append(dc.params, id)
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", paramIdx))
|
||||
paramIdx++
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%q IN (%s)", column, strings.Join(placeholders, ", ")), nil
|
||||
}
|
||||
|
||||
// compileAnyOp implements 'any' and 'not any' operators.
|
||||
// Mirrors: odoo/orm/domains.py for 'any' / 'not any' operators
|
||||
//
|
||||
// - any: EXISTS (SELECT 1 FROM comodel WHERE comodel.fk = model.id AND <subdomain>)
|
||||
// - not any: NOT EXISTS (...)
|
||||
//
|
||||
// The value must be a Domain (sub-domain) to apply on the comodel.
|
||||
func (dc *DomainCompiler) compileAnyOp(column string, value Value, negate bool) (string, error) {
|
||||
// Resolve the field to find the comodel
|
||||
f := dc.model.GetField(column)
|
||||
if f == nil {
|
||||
return "", fmt.Errorf("any/not any: field %q not found on %s", column, dc.model.Name())
|
||||
}
|
||||
|
||||
comodel := Registry.Get(f.Comodel)
|
||||
if comodel == nil {
|
||||
return "", fmt.Errorf("any/not any: comodel %q not found for field %q", f.Comodel, column)
|
||||
}
|
||||
|
||||
// The value should be a Domain (sub-domain for the comodel)
|
||||
subDomain, ok := value.(Domain)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("any/not any: value must be a Domain, got %T", value)
|
||||
}
|
||||
|
||||
// Compile the sub-domain against the comodel
|
||||
subCompiler := &DomainCompiler{model: comodel, env: dc.env}
|
||||
subWhere, subParams, err := subCompiler.Compile(subDomain)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("any/not any: compile subdomain: %w", err)
|
||||
}
|
||||
|
||||
// Rebase parameter indices: shift them by the current param count
|
||||
baseIdx := len(dc.params)
|
||||
dc.params = append(dc.params, subParams...)
|
||||
rebased := subWhere
|
||||
// Replace $N with $(N+baseIdx) in the sub-where clause
|
||||
for i := len(subParams); i >= 1; i-- {
|
||||
old := fmt.Sprintf("$%d", i)
|
||||
new := fmt.Sprintf("$%d", i+baseIdx)
|
||||
rebased = strings.ReplaceAll(rebased, old, new)
|
||||
}
|
||||
|
||||
// Determine the join condition based on field type
|
||||
var joinCond string
|
||||
switch f.Type {
|
||||
case TypeOne2many:
|
||||
// One2many: comodel has a FK pointing back to this model
|
||||
inverseField := f.InverseField
|
||||
if inverseField == "" {
|
||||
return "", fmt.Errorf("any/not any: One2many field %q has no InverseField", column)
|
||||
}
|
||||
inverseF := comodel.GetField(inverseField)
|
||||
if inverseF == nil {
|
||||
return "", fmt.Errorf("any/not any: inverse field %q not found on %s", inverseField, comodel.Name())
|
||||
}
|
||||
joinCond = fmt.Sprintf("%q.%q = %q.\"id\"", comodel.Table(), inverseF.Column(), dc.model.Table())
|
||||
|
||||
case TypeMany2many:
|
||||
// Many2many: use junction table
|
||||
relation := f.Relation
|
||||
if relation == "" {
|
||||
t1, t2 := dc.model.Table(), comodel.Table()
|
||||
if t1 > t2 {
|
||||
t1, t2 = t2, t1
|
||||
}
|
||||
relation = fmt.Sprintf("%s_%s_rel", t1, t2)
|
||||
}
|
||||
col1 := f.Column1
|
||||
if col1 == "" {
|
||||
col1 = dc.model.Table() + "_id"
|
||||
}
|
||||
col2 := f.Column2
|
||||
if col2 == "" {
|
||||
col2 = comodel.Table() + "_id"
|
||||
}
|
||||
joinCond = fmt.Sprintf(
|
||||
"%q.\"id\" IN (SELECT %q FROM %q WHERE %q = %q.\"id\")",
|
||||
comodel.Table(), col2, relation, col1, dc.model.Table(),
|
||||
)
|
||||
|
||||
case TypeMany2one:
|
||||
// Many2one: this model has the FK
|
||||
joinCond = fmt.Sprintf("%q.\"id\" = %q.%q", comodel.Table(), dc.model.Table(), f.Column())
|
||||
|
||||
default:
|
||||
return "", fmt.Errorf("any/not any: field %q is type %s, expected relational", column, f.Type)
|
||||
}
|
||||
|
||||
subJoins := subCompiler.JoinSQL()
|
||||
prefix := "EXISTS"
|
||||
if negate {
|
||||
prefix = "NOT EXISTS"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s (SELECT 1 FROM %q%s WHERE %s AND %s)",
|
||||
prefix, comodel.Table(), subJoins, joinCond, rebased,
|
||||
), nil
|
||||
}
|
||||
|
||||
// toInt64Slice normalizes a value to []int64 for hierarchy operators.
|
||||
func toInt64Slice(value Value) []int64 {
|
||||
switch v := value.(type) {
|
||||
case int64:
|
||||
return []int64{v}
|
||||
case int:
|
||||
return []int64{int64(v)}
|
||||
case int32:
|
||||
return []int64{int64(v)}
|
||||
case float64:
|
||||
return []int64{int64(v)}
|
||||
case []int64:
|
||||
return v
|
||||
case []int:
|
||||
out := make([]int64, len(v))
|
||||
for i, x := range v {
|
||||
out[i] = int64(x)
|
||||
}
|
||||
return out
|
||||
case []interface{}:
|
||||
out := make([]int64, 0, len(v))
|
||||
for _, x := range v {
|
||||
switch n := x.(type) {
|
||||
case int64:
|
||||
out = append(out, n)
|
||||
case int:
|
||||
out = append(out, int64(n))
|
||||
case float64:
|
||||
out = append(out, int64(n))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeSlice converts typed slices to []interface{} for IN/NOT IN operators.
|
||||
func normalizeSlice(value Value) []interface{} {
|
||||
switch v := value.(type) {
|
||||
|
||||
473
pkg/orm/domain_parse.go
Normal file
473
pkg/orm/domain_parse.go
Normal file
@@ -0,0 +1,473 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// ParseDomainString parses a Python-style domain_force string into an orm.Domain.
|
||||
// Mirrors: odoo/addons/base/models/ir_rule.py safe_eval(domain_force, eval_context)
|
||||
//
|
||||
// Supported syntax:
|
||||
// - Tuples: ('field', 'operator', value)
|
||||
// - Logical operators: '&', '|', '!'
|
||||
// - Values: string literals, numbers, True/False, None, list literals, context variables
|
||||
// - Context variables: user.id, company_id, user.company_id, company_ids
|
||||
//
|
||||
// The env parameter provides runtime context for variable resolution.
|
||||
func ParseDomainString(s string, env *Environment) (Domain, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" || s == "[]" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
p := &domainParser{
|
||||
input: []rune(s),
|
||||
pos: 0,
|
||||
env: env,
|
||||
}
|
||||
|
||||
return p.parseDomain()
|
||||
}
|
||||
|
||||
// domainParser is a simple recursive-descent parser for Python domain expressions.
|
||||
type domainParser struct {
|
||||
input []rune
|
||||
pos int
|
||||
env *Environment
|
||||
}
|
||||
|
||||
func (p *domainParser) parseDomain() (Domain, error) {
|
||||
p.skipWhitespace()
|
||||
if p.pos >= len(p.input) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if p.input[p.pos] != '[' {
|
||||
return nil, fmt.Errorf("domain_parse: expected '[' at position %d, got %c", p.pos, p.input[p.pos])
|
||||
}
|
||||
p.pos++ // consume '['
|
||||
|
||||
var nodes []DomainNode
|
||||
|
||||
for {
|
||||
p.skipWhitespace()
|
||||
if p.pos >= len(p.input) {
|
||||
return nil, fmt.Errorf("domain_parse: unexpected end of input")
|
||||
}
|
||||
if p.input[p.pos] == ']' {
|
||||
p.pos++ // consume ']'
|
||||
break
|
||||
}
|
||||
|
||||
// Skip commas between elements
|
||||
if p.input[p.pos] == ',' {
|
||||
p.pos++
|
||||
continue
|
||||
}
|
||||
|
||||
node, err := p.parseNode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
|
||||
// Convert the list of nodes into a proper Domain.
|
||||
// Odoo domains are in prefix (Polish) notation:
|
||||
// ['&', (a), (b)] means a AND b
|
||||
// If no explicit operator prefix, Odoo implicitly ANDs consecutive leaves.
|
||||
return normalizeDomainNodes(nodes), nil
|
||||
}
|
||||
|
||||
// normalizeDomainNodes adds implicit '&' operators between consecutive leaf nodes
|
||||
// that don't have an explicit operator, mirroring Odoo's behavior.
|
||||
func normalizeDomainNodes(nodes []DomainNode) Domain {
|
||||
if len(nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the domain already has operators in prefix position.
|
||||
// If first node is an operator, assume the domain is already in Polish notation.
|
||||
if _, isOp := nodes[0].(Operator); isOp {
|
||||
return Domain(nodes)
|
||||
}
|
||||
|
||||
// No prefix operators: implicitly AND all leaf conditions.
|
||||
if len(nodes) == 1 {
|
||||
return Domain{nodes[0]}
|
||||
}
|
||||
|
||||
// Multiple leaves without operators: AND them together.
|
||||
return And(nodes...)
|
||||
}
|
||||
|
||||
func (p *domainParser) parseNode() (DomainNode, error) {
|
||||
p.skipWhitespace()
|
||||
if p.pos >= len(p.input) {
|
||||
return nil, fmt.Errorf("domain_parse: unexpected end of input")
|
||||
}
|
||||
|
||||
ch := p.input[p.pos]
|
||||
|
||||
// Check for logical operators: '&', '|', '!'
|
||||
if ch == '\'' || ch == '"' {
|
||||
// Could be a string operator like '&' or '|' or '!'
|
||||
str, err := p.parseString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch str {
|
||||
case "&":
|
||||
return OpAnd, nil
|
||||
case "|":
|
||||
return OpOr, nil
|
||||
case "!":
|
||||
return OpNot, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("domain_parse: unexpected string %q where operator or tuple expected", str)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for tuple: (field, operator, value)
|
||||
if ch == '(' {
|
||||
return p.parseTuple()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("domain_parse: unexpected character %c at position %d", ch, p.pos)
|
||||
}
|
||||
|
||||
func (p *domainParser) parseTuple() (DomainNode, error) {
|
||||
if p.input[p.pos] != '(' {
|
||||
return nil, fmt.Errorf("domain_parse: expected '(' at position %d", p.pos)
|
||||
}
|
||||
p.pos++ // consume '('
|
||||
|
||||
// Parse field name (string)
|
||||
p.skipWhitespace()
|
||||
field, err := p.parseString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("domain_parse: field name: %w", err)
|
||||
}
|
||||
|
||||
p.skipWhitespace()
|
||||
p.expectChar(',')
|
||||
|
||||
// Parse operator (string)
|
||||
p.skipWhitespace()
|
||||
operator, err := p.parseString()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("domain_parse: operator: %w", err)
|
||||
}
|
||||
|
||||
p.skipWhitespace()
|
||||
p.expectChar(',')
|
||||
|
||||
// Parse value
|
||||
p.skipWhitespace()
|
||||
value, err := p.parseValue()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("domain_parse: value for (%s, %s, ...): %w", field, operator, err)
|
||||
}
|
||||
|
||||
p.skipWhitespace()
|
||||
if p.pos < len(p.input) && p.input[p.pos] == ')' {
|
||||
p.pos++ // consume ')'
|
||||
} else {
|
||||
return nil, fmt.Errorf("domain_parse: expected ')' at position %d", p.pos)
|
||||
}
|
||||
|
||||
return Condition{Field: field, Operator: operator, Value: value}, nil
|
||||
}
|
||||
|
||||
func (p *domainParser) parseValue() (Value, error) {
|
||||
p.skipWhitespace()
|
||||
if p.pos >= len(p.input) {
|
||||
return nil, fmt.Errorf("domain_parse: unexpected end of input in value")
|
||||
}
|
||||
|
||||
ch := p.input[p.pos]
|
||||
|
||||
// String literal
|
||||
if ch == '\'' || ch == '"' {
|
||||
return p.parseString()
|
||||
}
|
||||
|
||||
// List literal
|
||||
if ch == '[' {
|
||||
return p.parseList()
|
||||
}
|
||||
|
||||
// Tuple literal used as list value (some domain_force uses tuple syntax)
|
||||
if ch == '(' {
|
||||
return p.parseTupleAsList()
|
||||
}
|
||||
|
||||
// Number or negative number
|
||||
if ch == '-' || (ch >= '0' && ch <= '9') {
|
||||
return p.parseNumber()
|
||||
}
|
||||
|
||||
// Keywords or context variables
|
||||
if unicode.IsLetter(ch) || ch == '_' {
|
||||
return p.parseIdentOrKeyword()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("domain_parse: unexpected character %c at position %d in value", ch, p.pos)
|
||||
}
|
||||
|
||||
func (p *domainParser) parseString() (string, error) {
|
||||
if p.pos >= len(p.input) {
|
||||
return "", fmt.Errorf("domain_parse: unexpected end of input in string")
|
||||
}
|
||||
|
||||
quote := p.input[p.pos]
|
||||
if quote != '\'' && quote != '"' {
|
||||
return "", fmt.Errorf("domain_parse: expected quote at position %d, got %c", p.pos, quote)
|
||||
}
|
||||
p.pos++ // consume opening quote
|
||||
|
||||
var sb strings.Builder
|
||||
for p.pos < len(p.input) {
|
||||
ch := p.input[p.pos]
|
||||
if ch == '\\' && p.pos+1 < len(p.input) {
|
||||
p.pos++
|
||||
sb.WriteRune(p.input[p.pos])
|
||||
p.pos++
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
p.pos++ // consume closing quote
|
||||
return sb.String(), nil
|
||||
}
|
||||
sb.WriteRune(ch)
|
||||
p.pos++
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("domain_parse: unterminated string starting at position %d", p.pos)
|
||||
}
|
||||
|
||||
func (p *domainParser) parseNumber() (Value, error) {
|
||||
start := p.pos
|
||||
if p.input[p.pos] == '-' {
|
||||
p.pos++
|
||||
}
|
||||
|
||||
isFloat := false
|
||||
for p.pos < len(p.input) {
|
||||
ch := p.input[p.pos]
|
||||
if ch == '.' && !isFloat {
|
||||
isFloat = true
|
||||
p.pos++
|
||||
continue
|
||||
}
|
||||
if ch >= '0' && ch <= '9' {
|
||||
p.pos++
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
numStr := string(p.input[start:p.pos])
|
||||
|
||||
if isFloat {
|
||||
f, err := strconv.ParseFloat(numStr, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("domain_parse: invalid float %q: %w", numStr, err)
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
n, err := strconv.ParseInt(numStr, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("domain_parse: invalid integer %q: %w", numStr, err)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (p *domainParser) parseList() (Value, error) {
|
||||
if p.input[p.pos] != '[' {
|
||||
return nil, fmt.Errorf("domain_parse: expected '[' at position %d", p.pos)
|
||||
}
|
||||
p.pos++ // consume '['
|
||||
|
||||
var items []interface{}
|
||||
for {
|
||||
p.skipWhitespace()
|
||||
if p.pos >= len(p.input) {
|
||||
return nil, fmt.Errorf("domain_parse: unterminated list")
|
||||
}
|
||||
if p.input[p.pos] == ']' {
|
||||
p.pos++ // consume ']'
|
||||
break
|
||||
}
|
||||
if p.input[p.pos] == ',' {
|
||||
p.pos++
|
||||
continue
|
||||
}
|
||||
|
||||
val, err := p.parseValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, val)
|
||||
}
|
||||
|
||||
// Try to produce typed slices for common cases.
|
||||
return normalizeListValue(items), nil
|
||||
}
|
||||
|
||||
// parseTupleAsList parses a Python tuple literal (1, 2, 3) as a list value.
|
||||
func (p *domainParser) parseTupleAsList() (Value, error) {
|
||||
if p.input[p.pos] != '(' {
|
||||
return nil, fmt.Errorf("domain_parse: expected '(' at position %d", p.pos)
|
||||
}
|
||||
p.pos++ // consume '('
|
||||
|
||||
var items []interface{}
|
||||
for {
|
||||
p.skipWhitespace()
|
||||
if p.pos >= len(p.input) {
|
||||
return nil, fmt.Errorf("domain_parse: unterminated tuple-as-list")
|
||||
}
|
||||
if p.input[p.pos] == ')' {
|
||||
p.pos++ // consume ')'
|
||||
break
|
||||
}
|
||||
if p.input[p.pos] == ',' {
|
||||
p.pos++
|
||||
continue
|
||||
}
|
||||
|
||||
val, err := p.parseValue()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, val)
|
||||
}
|
||||
|
||||
return normalizeListValue(items), nil
|
||||
}
|
||||
|
||||
// normalizeListValue converts []interface{} to typed slices when all elements
|
||||
// share the same type, for compatibility with normalizeSlice in domain compilation.
|
||||
func normalizeListValue(items []interface{}) interface{} {
|
||||
if len(items) == 0 {
|
||||
return []int64{}
|
||||
}
|
||||
|
||||
// Check if all items are int64
|
||||
allInt := true
|
||||
for _, v := range items {
|
||||
if _, ok := v.(int64); !ok {
|
||||
allInt = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allInt {
|
||||
result := make([]int64, len(items))
|
||||
for i, v := range items {
|
||||
result[i] = v.(int64)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Check if all items are strings
|
||||
allStr := true
|
||||
for _, v := range items {
|
||||
if _, ok := v.(string); !ok {
|
||||
allStr = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allStr {
|
||||
result := make([]string, len(items))
|
||||
for i, v := range items {
|
||||
result[i] = v.(string)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
return items
|
||||
}
|
||||
|
||||
func (p *domainParser) parseIdentOrKeyword() (Value, error) {
|
||||
start := p.pos
|
||||
for p.pos < len(p.input) {
|
||||
ch := p.input[p.pos]
|
||||
if unicode.IsLetter(ch) || unicode.IsDigit(ch) || ch == '_' || ch == '.' {
|
||||
p.pos++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ident := string(p.input[start:p.pos])
|
||||
|
||||
switch ident {
|
||||
case "True":
|
||||
return true, nil
|
||||
case "False":
|
||||
return false, nil
|
||||
case "None":
|
||||
return nil, nil
|
||||
|
||||
// Context variables from _eval_context
|
||||
case "user.id":
|
||||
if p.env != nil {
|
||||
return p.env.UID(), nil
|
||||
}
|
||||
return int64(0), nil
|
||||
|
||||
case "company_id", "user.company_id":
|
||||
if p.env != nil {
|
||||
return p.env.CompanyID(), nil
|
||||
}
|
||||
return int64(0), nil
|
||||
|
||||
case "company_ids":
|
||||
if p.env != nil {
|
||||
return []int64{p.env.CompanyID()}, nil
|
||||
}
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
// Handle dotted identifiers that start with known prefixes.
|
||||
// e.g., user.company_id.id, user.partner_id.id, etc.
|
||||
if strings.HasPrefix(ident, "user.") {
|
||||
// For now, resolve common patterns. Unknown paths return 0/nil.
|
||||
switch ident {
|
||||
case "user.company_id.id":
|
||||
if p.env != nil {
|
||||
return p.env.CompanyID(), nil
|
||||
}
|
||||
return int64(0), nil
|
||||
case "user.company_ids.ids":
|
||||
if p.env != nil {
|
||||
return []int64{p.env.CompanyID()}, nil
|
||||
}
|
||||
return []int64{}, nil
|
||||
default:
|
||||
// Unknown user attribute: return 0 as safe fallback.
|
||||
return int64(0), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("domain_parse: unknown identifier %q at position %d", ident, start)
|
||||
}
|
||||
|
||||
func (p *domainParser) skipWhitespace() {
|
||||
for p.pos < len(p.input) && unicode.IsSpace(p.input[p.pos]) {
|
||||
p.pos++
|
||||
}
|
||||
}
|
||||
|
||||
func (p *domainParser) expectChar(ch rune) {
|
||||
p.skipWhitespace()
|
||||
if p.pos < len(p.input) && p.input[p.pos] == ch {
|
||||
p.pos++
|
||||
}
|
||||
// Tolerate missing comma (lenient parsing)
|
||||
}
|
||||
422
pkg/orm/read_group.go
Normal file
422
pkg/orm/read_group.go
Normal file
@@ -0,0 +1,422 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ReadGroupResult holds one group returned by ReadGroup.
|
||||
// Mirrors: one row from odoo/orm/models.py _read_group() result tuples.
|
||||
type ReadGroupResult struct {
|
||||
// GroupValues maps groupby spec → grouped value (e.g., "state" → "draft")
|
||||
GroupValues map[string]interface{}
|
||||
// AggValues maps aggregate spec → aggregated value (e.g., "amount_total:sum" → 1234.56)
|
||||
AggValues map[string]interface{}
|
||||
// Domain is the filter domain that selects records in this group.
|
||||
Domain []interface{}
|
||||
// Count is the number of records in this group (__count).
|
||||
Count int64
|
||||
}
|
||||
|
||||
// readGroupbyCol describes a parsed groupby column for ReadGroup.
|
||||
type readGroupbyCol struct {
|
||||
spec string // original spec, e.g. "date_order:month"
|
||||
fieldName string // field name, e.g. "date_order"
|
||||
granularity string // e.g. "month", "" if none
|
||||
sqlExpr string // SQL expression for SELECT and GROUP BY
|
||||
field *Field
|
||||
}
|
||||
|
||||
// ReadGroupOpts configures a ReadGroup call.
|
||||
type ReadGroupOpts struct {
|
||||
Offset int
|
||||
Limit int
|
||||
Order string
|
||||
}
|
||||
|
||||
// ReadGroup performs a grouped aggregation query.
|
||||
// Mirrors: odoo/orm/models.py BaseModel._read_group()
|
||||
//
|
||||
// groupby: list of groupby specs, e.g. ["state", "date_order:month", "partner_id"]
|
||||
// aggregates: list of aggregate specs, e.g. ["__count", "amount_total:sum", "id:count_distinct"]
|
||||
func (rs *Recordset) ReadGroup(domain Domain, groupby []string, aggregates []string, opts ...ReadGroupOpts) ([]ReadGroupResult, error) {
|
||||
m := rs.model
|
||||
opt := ReadGroupOpts{}
|
||||
if len(opts) > 0 {
|
||||
opt = opts[0]
|
||||
}
|
||||
|
||||
// Apply record rules
|
||||
domain = ApplyRecordRules(rs.env, m, domain)
|
||||
|
||||
// Compile domain to WHERE clause
|
||||
compiler := &DomainCompiler{model: m, env: rs.env}
|
||||
where, params, err := compiler.Compile(domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("orm: read_group %s: %w", m.name, err)
|
||||
}
|
||||
|
||||
// Parse groupby specs
|
||||
var gbCols []readGroupbyCol
|
||||
|
||||
for _, spec := range groupby {
|
||||
fieldName, granularity := parseGroupbySpec(spec)
|
||||
f := m.GetField(fieldName)
|
||||
if f == nil {
|
||||
return nil, fmt.Errorf("orm: read_group: field %q not found on %s", fieldName, m.name)
|
||||
}
|
||||
|
||||
sqlExpr := groupbySQLExpr(m.table, f, granularity)
|
||||
gbCols = append(gbCols, readGroupbyCol{
|
||||
spec: spec,
|
||||
fieldName: fieldName,
|
||||
granularity: granularity,
|
||||
sqlExpr: sqlExpr,
|
||||
field: f,
|
||||
})
|
||||
}
|
||||
|
||||
// Parse aggregate specs
|
||||
type aggCol struct {
|
||||
spec string // original spec, e.g. "amount_total:sum"
|
||||
fieldName string
|
||||
function string // e.g. "sum", "count", "avg"
|
||||
sqlExpr string
|
||||
}
|
||||
var aggCols []aggCol
|
||||
|
||||
for _, spec := range aggregates {
|
||||
if spec == "__count" {
|
||||
aggCols = append(aggCols, aggCol{
|
||||
spec: "__count",
|
||||
sqlExpr: "COUNT(*)",
|
||||
})
|
||||
continue
|
||||
}
|
||||
fieldName, function := parseAggregateSpec(spec)
|
||||
if function == "" {
|
||||
return nil, fmt.Errorf("orm: read_group: aggregate %q missing function (expected field:func)", spec)
|
||||
}
|
||||
f := m.GetField(fieldName)
|
||||
if f == nil {
|
||||
return nil, fmt.Errorf("orm: read_group: field %q not found on %s", fieldName, m.name)
|
||||
}
|
||||
sqlFunc := aggregateSQLFunc(function, fmt.Sprintf("%q.%q", m.table, f.Column()))
|
||||
if sqlFunc == "" {
|
||||
return nil, fmt.Errorf("orm: read_group: unknown aggregate function %q", function)
|
||||
}
|
||||
aggCols = append(aggCols, aggCol{
|
||||
spec: spec,
|
||||
fieldName: fieldName,
|
||||
function: function,
|
||||
sqlExpr: sqlFunc,
|
||||
})
|
||||
}
|
||||
|
||||
// Build SELECT clause
|
||||
var selectParts []string
|
||||
for _, gb := range gbCols {
|
||||
selectParts = append(selectParts, gb.sqlExpr)
|
||||
}
|
||||
for _, agg := range aggCols {
|
||||
selectParts = append(selectParts, agg.sqlExpr)
|
||||
}
|
||||
if len(selectParts) == 0 {
|
||||
selectParts = append(selectParts, "COUNT(*)")
|
||||
}
|
||||
|
||||
// Build GROUP BY clause
|
||||
var groupByParts []string
|
||||
for _, gb := range gbCols {
|
||||
groupByParts = append(groupByParts, gb.sqlExpr)
|
||||
}
|
||||
|
||||
// Build ORDER BY
|
||||
orderSQL := ""
|
||||
if opt.Order != "" {
|
||||
orderSQL = opt.Order
|
||||
} else if len(gbCols) > 0 {
|
||||
// Default: order by groupby columns
|
||||
var orderParts []string
|
||||
for _, gb := range gbCols {
|
||||
orderParts = append(orderParts, gb.sqlExpr)
|
||||
}
|
||||
orderSQL = strings.Join(orderParts, ", ")
|
||||
}
|
||||
|
||||
// Assemble query
|
||||
joinSQL := compiler.JoinSQL()
|
||||
query := fmt.Sprintf("SELECT %s FROM %q%s WHERE %s",
|
||||
strings.Join(selectParts, ", "),
|
||||
m.table,
|
||||
joinSQL,
|
||||
where,
|
||||
)
|
||||
if len(groupByParts) > 0 {
|
||||
query += " GROUP BY " + strings.Join(groupByParts, ", ")
|
||||
}
|
||||
if orderSQL != "" {
|
||||
query += " ORDER BY " + orderSQL
|
||||
}
|
||||
if opt.Limit > 0 {
|
||||
query += fmt.Sprintf(" LIMIT %d", opt.Limit)
|
||||
}
|
||||
if opt.Offset > 0 {
|
||||
query += fmt.Sprintf(" OFFSET %d", opt.Offset)
|
||||
}
|
||||
|
||||
// Execute
|
||||
rows, err := rs.env.tx.Query(rs.env.ctx, query, params...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("orm: read_group %s: %w", m.name, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Scan results
|
||||
totalCols := len(gbCols) + len(aggCols)
|
||||
if totalCols == 0 {
|
||||
totalCols = 1 // COUNT(*) fallback
|
||||
}
|
||||
|
||||
var results []ReadGroupResult
|
||||
for rows.Next() {
|
||||
scanDest := make([]interface{}, totalCols)
|
||||
for i := range scanDest {
|
||||
scanDest[i] = new(interface{})
|
||||
}
|
||||
if err := rows.Scan(scanDest...); err != nil {
|
||||
return nil, fmt.Errorf("orm: read_group scan %s: %w", m.name, err)
|
||||
}
|
||||
|
||||
result := ReadGroupResult{
|
||||
GroupValues: make(map[string]interface{}),
|
||||
AggValues: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Extract groupby values
|
||||
for i, gb := range gbCols {
|
||||
val := *(scanDest[i].(*interface{}))
|
||||
result.GroupValues[gb.spec] = val
|
||||
}
|
||||
|
||||
// Extract aggregate values
|
||||
for i, agg := range aggCols {
|
||||
val := *(scanDest[len(gbCols)+i].(*interface{}))
|
||||
if agg.spec == "__count" {
|
||||
result.Count = asInt64(val)
|
||||
result.AggValues["__count"] = result.Count
|
||||
} else {
|
||||
result.AggValues[agg.spec] = val
|
||||
}
|
||||
}
|
||||
|
||||
// If __count not explicitly requested, add it from COUNT(*)
|
||||
if _, hasCount := result.AggValues["__count"]; !hasCount {
|
||||
result.Count = 0
|
||||
}
|
||||
|
||||
// Build domain for this group
|
||||
result.Domain = buildGroupDomain(gbCols, scanDest)
|
||||
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("orm: read_group %s: %w", m.name, err)
|
||||
}
|
||||
|
||||
// Post-process: resolve Many2one groupby values to [id, display_name]
|
||||
for _, gb := range gbCols {
|
||||
if gb.field.Type == TypeMany2one && gb.field.Comodel != "" {
|
||||
if err := rs.resolveM2OGroupby(gb.spec, gb.field, results); err != nil {
|
||||
// Non-fatal: log and continue with raw IDs
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// resolveM2OGroupby replaces raw FK IDs in group results with [id, display_name] pairs.
|
||||
func (rs *Recordset) resolveM2OGroupby(spec string, f *Field, results []ReadGroupResult) error {
|
||||
// Collect unique IDs
|
||||
idSet := make(map[int64]bool)
|
||||
for _, r := range results {
|
||||
if id := asInt64(r.GroupValues[spec]); id > 0 {
|
||||
idSet[id] = true
|
||||
}
|
||||
}
|
||||
if len(idSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ids []int64
|
||||
for id := range idSet {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
// Fetch display names
|
||||
comodelRS := rs.env.Model(f.Comodel).Browse(ids...)
|
||||
names, err := comodelRS.NameGet()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Replace values
|
||||
for i, r := range results {
|
||||
rawID := asInt64(r.GroupValues[spec])
|
||||
if rawID > 0 {
|
||||
name := names[rawID]
|
||||
results[i].GroupValues[spec] = []interface{}{rawID, name}
|
||||
} else {
|
||||
results[i].GroupValues[spec] = false
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseGroupbySpec splits "field:granularity" into field name and granularity.
|
||||
// Mirrors: odoo/orm/models.py parse_read_group_spec() for groupby
|
||||
func parseGroupbySpec(spec string) (fieldName, granularity string) {
|
||||
parts := strings.SplitN(spec, ":", 2)
|
||||
fieldName = parts[0]
|
||||
if len(parts) > 1 {
|
||||
granularity = parts[1]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// parseAggregateSpec splits "field:function" into field name and aggregate function.
|
||||
// Mirrors: odoo/orm/models.py parse_read_group_spec() for aggregates
|
||||
func parseAggregateSpec(spec string) (fieldName, function string) {
|
||||
parts := strings.SplitN(spec, ":", 2)
|
||||
fieldName = parts[0]
|
||||
if len(parts) > 1 {
|
||||
function = parts[1]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// groupbySQLExpr returns the SQL expression for a GROUP BY column.
|
||||
// Mirrors: odoo/orm/models.py _read_group_groupby()
|
||||
func groupbySQLExpr(table string, f *Field, granularity string) string {
|
||||
col := fmt.Sprintf("%q.%q", table, f.Column())
|
||||
|
||||
if granularity == "" {
|
||||
// Boolean fields: COALESCE to false (like Python Odoo)
|
||||
if f.Type == TypeBoolean {
|
||||
return fmt.Sprintf("COALESCE(%s, FALSE)", col)
|
||||
}
|
||||
return col
|
||||
}
|
||||
|
||||
// Date/datetime granularity
|
||||
// Mirrors: odoo/orm/models.py _read_group_groupby() date_trunc branch
|
||||
switch granularity {
|
||||
case "day", "month", "quarter", "year":
|
||||
expr := fmt.Sprintf("date_trunc('%s', %s::timestamp)", granularity, col)
|
||||
if f.Type == TypeDate {
|
||||
expr += "::date"
|
||||
}
|
||||
return expr
|
||||
case "week":
|
||||
// ISO week: truncate to Monday
|
||||
expr := fmt.Sprintf("date_trunc('week', %s::timestamp)", col)
|
||||
if f.Type == TypeDate {
|
||||
expr += "::date"
|
||||
}
|
||||
return expr
|
||||
case "year_number":
|
||||
return fmt.Sprintf("EXTRACT(YEAR FROM %s)", col)
|
||||
case "quarter_number":
|
||||
return fmt.Sprintf("EXTRACT(QUARTER FROM %s)", col)
|
||||
case "month_number":
|
||||
return fmt.Sprintf("EXTRACT(MONTH FROM %s)", col)
|
||||
case "iso_week_number":
|
||||
return fmt.Sprintf("EXTRACT(WEEK FROM %s)", col)
|
||||
case "day_of_year":
|
||||
return fmt.Sprintf("EXTRACT(DOY FROM %s)", col)
|
||||
case "day_of_month":
|
||||
return fmt.Sprintf("EXTRACT(DAY FROM %s)", col)
|
||||
case "day_of_week":
|
||||
return fmt.Sprintf("EXTRACT(ISODOW FROM %s)", col)
|
||||
case "hour_number":
|
||||
return fmt.Sprintf("EXTRACT(HOUR FROM %s)", col)
|
||||
case "minute_number":
|
||||
return fmt.Sprintf("EXTRACT(MINUTE FROM %s)", col)
|
||||
case "second_number":
|
||||
return fmt.Sprintf("EXTRACT(SECOND FROM %s)", col)
|
||||
default:
|
||||
// Unknown granularity: fall back to plain column
|
||||
return col
|
||||
}
|
||||
}
|
||||
|
||||
// aggregateSQLFunc returns the SQL aggregate expression.
|
||||
// Mirrors: odoo/orm/models.py READ_GROUP_AGGREGATE
|
||||
func aggregateSQLFunc(function, column string) string {
|
||||
switch function {
|
||||
case "sum":
|
||||
return fmt.Sprintf("SUM(%s)", column)
|
||||
case "avg":
|
||||
return fmt.Sprintf("AVG(%s)", column)
|
||||
case "max":
|
||||
return fmt.Sprintf("MAX(%s)", column)
|
||||
case "min":
|
||||
return fmt.Sprintf("MIN(%s)", column)
|
||||
case "count":
|
||||
return fmt.Sprintf("COUNT(%s)", column)
|
||||
case "count_distinct":
|
||||
return fmt.Sprintf("COUNT(DISTINCT %s)", column)
|
||||
case "bool_and":
|
||||
return fmt.Sprintf("BOOL_AND(%s)", column)
|
||||
case "bool_or":
|
||||
return fmt.Sprintf("BOOL_OR(%s)", column)
|
||||
case "array_agg":
|
||||
return fmt.Sprintf("ARRAY_AGG(%s)", column)
|
||||
case "array_agg_distinct":
|
||||
return fmt.Sprintf("ARRAY_AGG(DISTINCT %s)", column)
|
||||
case "recordset":
|
||||
return fmt.Sprintf("ARRAY_AGG(%s)", column)
|
||||
case "sum_currency":
|
||||
// Simplified: SUM without currency conversion (full impl needs exchange rates)
|
||||
return fmt.Sprintf("SUM(%s)", column)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// buildGroupDomain builds a domain that selects all records in this group.
|
||||
func buildGroupDomain(gbCols []readGroupbyCol, scanDest []interface{}) []interface{} {
|
||||
var domain []interface{}
|
||||
for i, gb := range gbCols {
|
||||
val := *(scanDest[i].(*interface{}))
|
||||
if val == nil {
|
||||
domain = append(domain, []interface{}{gb.fieldName, "=", false})
|
||||
} else if gb.granularity != "" && isTimeGranularity(gb.granularity) {
|
||||
// For date grouping, build a range domain
|
||||
// The raw value is the truncated date — client uses __range instead
|
||||
domain = append(domain, []interface{}{gb.fieldName, "=", val})
|
||||
} else {
|
||||
domain = append(domain, []interface{}{gb.fieldName, "=", val})
|
||||
}
|
||||
}
|
||||
return domain
|
||||
}
|
||||
|
||||
// isTimeGranularity returns true for date/time truncation granularities.
|
||||
func isTimeGranularity(g string) bool {
|
||||
switch g {
|
||||
case "day", "week", "month", "quarter", "year":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// asInt64 converts various numeric types to int64 (ignoring ok).
|
||||
// Uses toInt64 from relational.go when bool result is needed.
|
||||
func asInt64(v interface{}) int64 {
|
||||
n, _ := toInt64(v)
|
||||
return n
|
||||
}
|
||||
@@ -140,6 +140,12 @@ func (rs *Recordset) Create(vals Values) (*Recordset, error) {
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
// Odoo sends false for empty fields; convert to nil for non-boolean types
|
||||
val = sanitizeFieldValue(f, val)
|
||||
// Skip nil values (let DB use column default)
|
||||
if val == nil {
|
||||
continue
|
||||
}
|
||||
columns = append(columns, fmt.Sprintf("%q", f.Column()))
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", idx))
|
||||
args = append(args, val)
|
||||
@@ -239,6 +245,9 @@ func (rs *Recordset) Write(vals Values) error {
|
||||
continue
|
||||
}
|
||||
|
||||
// Odoo sends false for empty fields; convert to nil for non-boolean types
|
||||
val = sanitizeFieldValue(f, val)
|
||||
|
||||
setClauses = append(setClauses, fmt.Sprintf("%q = $%d", f.Column(), idx))
|
||||
args = append(args, val)
|
||||
idx++
|
||||
@@ -585,7 +594,7 @@ func (rs *Recordset) Search(domain Domain, opts ...SearchOpts) (*Recordset, erro
|
||||
domain = ApplyRecordRules(rs.env, m, domain)
|
||||
|
||||
// Compile domain to SQL
|
||||
compiler := &DomainCompiler{model: m}
|
||||
compiler := &DomainCompiler{model: m, env: rs.env}
|
||||
where, params, err := compiler.Compile(domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("orm: search %s: %w", m.name, err)
|
||||
@@ -638,7 +647,7 @@ func (rs *Recordset) Search(domain Domain, opts ...SearchOpts) (*Recordset, erro
|
||||
func (rs *Recordset) SearchCount(domain Domain) (int64, error) {
|
||||
m := rs.model
|
||||
|
||||
compiler := &DomainCompiler{model: m}
|
||||
compiler := &DomainCompiler{model: m, env: rs.env}
|
||||
where, params, err := compiler.Compile(domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("orm: search_count %s: %w", m.name, err)
|
||||
@@ -859,3 +868,31 @@ func processRelationalCommands(env *Environment, m *Model, parentID int64, vals
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sanitizeFieldValue converts Odoo's false/empty values to Go-native types
|
||||
// suitable for PostgreSQL. Odoo sends false for empty string/numeric/relational
|
||||
// fields; PostgreSQL rejects false for varchar/int columns.
|
||||
// Mirrors: odoo/orm/fields.py convert_to_column()
|
||||
func sanitizeFieldValue(f *Field, val interface{}) interface{} {
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle the Odoo false → nil conversion for non-boolean fields
|
||||
if b, ok := val.(bool); ok && !b {
|
||||
if f.Type == TypeBoolean {
|
||||
return false // Keep false for boolean fields
|
||||
}
|
||||
return nil // Convert false → NULL for all other types
|
||||
}
|
||||
|
||||
// Handle float→int conversion for integer/M2O fields
|
||||
switch f.Type {
|
||||
case TypeInteger, TypeMany2one:
|
||||
if fv, ok := val.(float64); ok {
|
||||
return int64(fv)
|
||||
}
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
@@ -275,6 +275,8 @@ func toInt64(v interface{}) (int64, bool) {
|
||||
return int64(n), true
|
||||
case int64:
|
||||
return n, true
|
||||
case int32:
|
||||
return int64(n), true
|
||||
case int:
|
||||
return int64(n), true
|
||||
}
|
||||
|
||||
148
pkg/orm/rules.go
148
pkg/orm/rules.go
@@ -10,12 +10,12 @@ import (
|
||||
//
|
||||
// Rules work as follows:
|
||||
// - Global rules (no groups) are AND-ed together
|
||||
// - Group rules are OR-ed within the group set
|
||||
// - The final domain is: global_rules AND (group_rule_1 OR group_rule_2 OR ...)
|
||||
// - Group rules (user belongs to one of the rule's groups) are OR-ed together
|
||||
// - The final domain is: original AND global_rules AND (group_rule_1 OR group_rule_2 OR ...)
|
||||
//
|
||||
// Implementation:
|
||||
// 1. Built-in company filter (for models with company_id)
|
||||
// 2. Custom ir.rule records loaded from the database
|
||||
// 2. Custom ir.rule records loaded from the database, domain_force parsed
|
||||
func ApplyRecordRules(env *Environment, m *Model, domain Domain) Domain {
|
||||
if env.su {
|
||||
return domain // Superuser bypasses record rules
|
||||
@@ -38,59 +38,143 @@ func ApplyRecordRules(env *Environment, m *Model, domain Domain) Domain {
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Load custom ir.rule records from DB
|
||||
// Mirrors: odoo/addons/base/models/ir_rule.py IrRule._compute_domain()
|
||||
// 2. Load ir.rule records from DB
|
||||
// Mirrors: odoo/addons/base/models/ir_rule.py IrRule._get_rules() + _compute_domain()
|
||||
//
|
||||
// Query rules that apply to this model for the current user:
|
||||
// - Rule must be active and have perm_read = true
|
||||
// - Either the rule has no group restriction (global rule),
|
||||
// or the user belongs to one of the rule's groups.
|
||||
// Use a savepoint so that a failed query (e.g., missing junction table)
|
||||
// doesn't abort the parent transaction.
|
||||
// - Either the rule is global (no groups assigned),
|
||||
// or the user belongs to one of the rule's groups via rule_group_rel.
|
||||
// Use a savepoint so that a failed query (e.g., missing table) doesn't abort the parent tx.
|
||||
sp, spErr := env.tx.Begin(env.ctx)
|
||||
if spErr != nil {
|
||||
return domain
|
||||
}
|
||||
|
||||
rows, err := sp.Query(env.ctx,
|
||||
`SELECT r.id, r.domain_force, COALESCE(r.global, false)
|
||||
`SELECT r.id, r.domain_force, COALESCE(r."global", false) AS is_global
|
||||
FROM ir_rule r
|
||||
JOIN ir_model m ON m.id = r.model_id
|
||||
WHERE m.model = $1 AND r.active = true
|
||||
AND r.perm_read = true`,
|
||||
m.Name())
|
||||
WHERE m.model = $1
|
||||
AND r.active = true
|
||||
AND r.perm_read = true
|
||||
AND (
|
||||
r."global" = true
|
||||
OR r.id IN (
|
||||
SELECT rg.rule_group_id
|
||||
FROM rule_group_rel rg
|
||||
JOIN res_groups_users_rel gu ON gu.gid = rg.group_id
|
||||
WHERE gu.uid = $2
|
||||
)
|
||||
)
|
||||
ORDER BY r.id`,
|
||||
m.Name(), env.UID())
|
||||
if err != nil {
|
||||
sp.Rollback(env.ctx)
|
||||
return domain
|
||||
}
|
||||
defer func() {
|
||||
rows.Close()
|
||||
sp.Commit(env.ctx)
|
||||
}()
|
||||
|
||||
// Collect domain_force strings from matching rules
|
||||
// TODO: parse domain_force strings into Domain objects and merge them
|
||||
ruleCount := 0
|
||||
type ruleRow struct {
|
||||
id int64
|
||||
domainForce *string
|
||||
global bool
|
||||
}
|
||||
var rules []ruleRow
|
||||
|
||||
for rows.Next() {
|
||||
var ruleID int64
|
||||
var domainForce *string
|
||||
var global bool
|
||||
if err := rows.Scan(&ruleID, &domainForce, &global); err != nil {
|
||||
var r ruleRow
|
||||
if err := rows.Scan(&r.id, &r.domainForce, &r.global); err != nil {
|
||||
continue
|
||||
}
|
||||
ruleCount++
|
||||
// TODO: parse domainForce (Python-style domain string) into Domain
|
||||
// and AND global rules / OR group rules into the result domain.
|
||||
// For now, rules are loaded but domain parsing is deferred.
|
||||
_ = domainForce
|
||||
_ = global
|
||||
rules = append(rules, r)
|
||||
}
|
||||
if ruleCount > 0 {
|
||||
log.Printf("orm: loaded %d ir.rule record(s) for %s (domain parsing pending)", ruleCount, m.Name())
|
||||
rows.Close()
|
||||
if err := sp.Commit(env.ctx); err != nil {
|
||||
// Non-fatal: rules already read
|
||||
_ = err
|
||||
}
|
||||
|
||||
if len(rules) == 0 {
|
||||
return domain
|
||||
}
|
||||
|
||||
// Parse domain_force strings and split into global vs. group rules.
|
||||
// Mirrors: odoo/addons/base/models/ir_rule.py IrRule._compute_domain()
|
||||
// global rules → AND together
|
||||
// group rules → OR together
|
||||
// final = original AND all_global AND (group_1 OR group_2 OR ...)
|
||||
var globalDomains []DomainNode
|
||||
var groupDomains []DomainNode
|
||||
parseErrors := 0
|
||||
|
||||
for _, r := range rules {
|
||||
if r.domainForce == nil || *r.domainForce == "" || *r.domainForce == "[]" {
|
||||
// Empty domain_force = match everything, skip
|
||||
continue
|
||||
}
|
||||
|
||||
parsed, err := ParseDomainString(*r.domainForce, env)
|
||||
if err != nil {
|
||||
parseErrors++
|
||||
log.Printf("orm: failed to parse domain_force for ir.rule %d: %v (raw: %s)", r.id, err, *r.domainForce)
|
||||
continue
|
||||
}
|
||||
if len(parsed) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if r.global {
|
||||
// Global rule: wrap as a single node for AND-ing
|
||||
globalDomains = append(globalDomains, domainAsNode(parsed))
|
||||
} else {
|
||||
// Group rule: wrap as a single node for OR-ing
|
||||
groupDomains = append(groupDomains, domainAsNode(parsed))
|
||||
}
|
||||
}
|
||||
|
||||
if parseErrors > 0 {
|
||||
log.Printf("orm: %d ir.rule domain_force parse error(s) for %s", parseErrors, m.Name())
|
||||
}
|
||||
|
||||
// Merge group domains with OR
|
||||
if len(groupDomains) > 0 {
|
||||
orDomain := Or(groupDomains...)
|
||||
globalDomains = append(globalDomains, domainAsNode(orDomain))
|
||||
}
|
||||
|
||||
// AND all rule domains into the original domain
|
||||
if len(globalDomains) > 0 {
|
||||
ruleDomain := And(globalDomains...)
|
||||
if len(domain) == 0 {
|
||||
domain = ruleDomain
|
||||
} else {
|
||||
result := Domain{OpAnd}
|
||||
result = append(result, domain...)
|
||||
result = append(result, ruleDomain...)
|
||||
domain = result
|
||||
}
|
||||
}
|
||||
|
||||
return domain
|
||||
}
|
||||
|
||||
// domainAsNode wraps a Domain (which is a []DomainNode) into a single DomainNode
|
||||
// so it can be used as an operand for And() / Or().
|
||||
// If the domain has a single node, return it directly.
|
||||
// If multiple nodes, wrap in a domainGroup.
|
||||
func domainAsNode(d Domain) DomainNode {
|
||||
if len(d) == 1 {
|
||||
return d[0]
|
||||
}
|
||||
return domainGroup(d)
|
||||
}
|
||||
|
||||
// domainGroup wraps a Domain as a single DomainNode for use in And()/Or() combinations.
|
||||
// When compiled, it produces the same SQL as the contained domain.
|
||||
type domainGroup Domain
|
||||
|
||||
func (dg domainGroup) isDomainNode() {}
|
||||
|
||||
// CheckRecordRuleAccess verifies the user can access specific record IDs.
|
||||
// Returns an error if any record is not accessible.
|
||||
func CheckRecordRuleAccess(env *Environment, m *Model, ids []int64, perm string) error {
|
||||
|
||||
Reference in New Issue
Block a user