feat: Portal, Email Inbound, Discuss + module improvements

- Portal: /my/* routes, signup, password reset, portal user support
- Email Inbound: IMAP polling (go-imap/v2), thread matching
- Discuss: mail.channel, long-polling bus, DM, unread count
- Cron: ir.cron runner (goroutine scheduler)
- Bank Import, CSV/Excel Import
- Automation (ir.actions.server)
- Fetchmail service
- HR Payroll model
- Various fixes across account, sale, stock, purchase, crm, hr, project

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Marc
2026-04-12 18:41:57 +02:00
parent 2c7c1e6c88
commit 66383adf06
87 changed files with 14696 additions and 654 deletions

View File

@@ -1,6 +1,9 @@
package orm
import "fmt"
import (
"fmt"
"log"
)
// ComputeFunc is a function that computes field values for a recordset.
// Mirrors: @api.depends decorated methods in Odoo.
@@ -253,7 +256,7 @@ func RunOnchangeComputes(m *Model, env *Environment, currentVals Values, changed
computed, err := fn(rs)
if err != nil {
// Non-fatal: skip failed computes during onchange
log.Printf("orm: onchange compute %s.%s failed: %v", m.Name(), fieldName, err)
continue
}
for k, v := range computed {

View File

@@ -2,6 +2,8 @@ package orm
import (
"fmt"
"regexp"
"strconv"
"strings"
)
@@ -152,6 +154,8 @@ func (dc *DomainCompiler) JoinSQL() string {
return " " + strings.Join(parts, " ")
}
// compileNodes compiles domain nodes in Polish (prefix) notation.
// Returns the SQL string and the number of nodes consumed from the domain starting at pos.
func (dc *DomainCompiler) compileNodes(domain Domain, pos int) (string, error) {
if pos >= len(domain) {
return "TRUE", nil
@@ -167,7 +171,8 @@ func (dc *DomainCompiler) compileNodes(domain Domain, pos int) (string, error) {
if err != nil {
return "", err
}
right, err := dc.compileNodes(domain, pos+2)
leftSize := nodeSize(domain, pos+1)
right, err := dc.compileNodes(domain, pos+1+leftSize)
if err != nil {
return "", err
}
@@ -178,7 +183,8 @@ func (dc *DomainCompiler) compileNodes(domain Domain, pos int) (string, error) {
if err != nil {
return "", err
}
right, err := dc.compileNodes(domain, pos+2)
leftSize := nodeSize(domain, pos+1)
right, err := dc.compileNodes(domain, pos+1+leftSize)
if err != nil {
return "", err
}
@@ -196,8 +202,6 @@ func (dc *DomainCompiler) compileNodes(domain Domain, pos int) (string, error) {
return dc.compileCondition(n)
case domainGroup:
// domainGroup wraps a sub-domain as a single node.
// Compile it recursively as a full domain.
subSQL, _, err := dc.compileDomainGroup(Domain(n))
if err != nil {
return "", err
@@ -208,6 +212,28 @@ func (dc *DomainCompiler) compileNodes(domain Domain, pos int) (string, error) {
return "", fmt.Errorf("unexpected domain node at position %d: %v", pos, node)
}
// nodeSize returns the number of domain nodes consumed by the subtree at pos.
// Operators (&, |) consume 1 + left subtree + right subtree.
// NOT consumes 1 + inner subtree. Leaf nodes consume 1.
func nodeSize(domain Domain, pos int) int {
if pos >= len(domain) {
return 0
}
switch n := domain[pos].(type) {
case Operator:
_ = n
switch domain[pos].(Operator) {
case OpAnd, OpOr:
leftSize := nodeSize(domain, pos+1)
rightSize := nodeSize(domain, pos+1+leftSize)
return 1 + leftSize + rightSize
case OpNot:
return 1 + nodeSize(domain, pos+1)
}
}
return 1 // Condition or domainGroup = 1 node
}
// compileDomainGroup compiles a sub-domain that was wrapped via domainGroup.
// It reuses the same DomainCompiler (sharing params and joins) so parameter
// indices stay consistent with the outer query.
@@ -227,14 +253,12 @@ func (dc *DomainCompiler) compileCondition(c Condition) (string, error) {
return "", fmt.Errorf("invalid operator: %q", c.Operator)
}
// Handle dot notation (e.g., "partner_id.name")
// Handle dot notation (e.g., "partner_id.name", "partner_id.country_id.code")
// by generating LEFT JOINs through the M2O relational chain.
parts := strings.Split(c.Field, ".")
column := parts[0]
// TODO: Handle JOINs for dot notation paths
// For now, only support direct fields
if len(parts) > 1 {
// Placeholder for JOIN resolution
return dc.compileJoinedCondition(parts, c.Operator, c.Value)
}
@@ -285,7 +309,7 @@ func (dc *DomainCompiler) compileJoinedCondition(fieldPath []string, operator st
dc.joins = append(dc.joins, joinClause{
table: comodel.Table(),
alias: alias,
on: fmt.Sprintf("%s.%q = %q.\"id\"", currentAlias, f.Column(), alias),
on: fmt.Sprintf("%q.%q = %q.\"id\"", currentAlias, f.Column(), alias),
})
currentModel = comodel
@@ -293,8 +317,12 @@ func (dc *DomainCompiler) compileJoinedCondition(fieldPath []string, operator st
}
// The last segment is the actual field to filter on
leafField := fieldPath[len(fieldPath)-1]
qualifiedColumn := fmt.Sprintf("%s.%q", currentAlias, leafField)
leafFieldName := fieldPath[len(fieldPath)-1]
leafCol := leafFieldName
if lf := currentModel.GetField(leafFieldName); lf != nil {
leafCol = lf.Column()
}
qualifiedColumn := fmt.Sprintf("%q.%q", currentAlias, leafCol)
return dc.compileQualifiedCondition(qualifiedColumn, operator, value)
}
@@ -528,13 +556,8 @@ func (dc *DomainCompiler) compileAnyOp(column string, value Value, negate bool)
// Rebase parameter indices: shift them by the current param count
baseIdx := len(dc.params)
dc.params = append(dc.params, subParams...)
rebased := subWhere
// Replace $N with $(N+baseIdx) in the sub-where clause
for i := len(subParams); i >= 1; i-- {
old := fmt.Sprintf("$%d", i)
new := fmt.Sprintf("$%d", i+baseIdx)
rebased = strings.ReplaceAll(rebased, old, new)
}
// Replace $N with $(N+baseIdx) using regex to avoid $1 matching $10
rebased := rebaseParams(subWhere, baseIdx)
// Determine the join condition based on field type
var joinCond string
@@ -676,3 +699,14 @@ func wrapLikeValue(value Value) Value {
}
return "%" + s + "%"
}
// rebaseParams shifts $N placeholders in a SQL string by baseIdx.
// Uses regex to avoid $1 matching inside $10.
var paramRegex = regexp.MustCompile(`\$(\d+)`)
func rebaseParams(sql string, baseIdx int) string {
return paramRegex.ReplaceAllStringFunc(sql, func(match string) string {
n, _ := strconv.Atoi(match[1:])
return fmt.Sprintf("$%d", n+baseIdx)
})
}

View File

@@ -45,8 +45,9 @@ type Model struct {
checkCompany bool // Enforce multi-company record rules
// Hooks
BeforeCreate func(env *Environment, vals Values) error // Called before INSERT
DefaultGet func(env *Environment, fields []string) Values // Dynamic defaults (e.g., from DB)
BeforeCreate func(env *Environment, vals Values) error // Called before INSERT
BeforeWrite func(env *Environment, ids []int64, vals Values) error // Called before UPDATE — for state guards
DefaultGet func(env *Environment, fields []string) Values // Dynamic defaults (e.g., from DB)
Constraints []ConstraintFunc // Validation constraints
Methods map[string]MethodFunc // Named business methods
@@ -453,3 +454,32 @@ func (m *Model) Many2manyTableSQL() []string {
}
return stmts
}
// StateGuard returns a BeforeWrite function that prevents modifications on records
// in certain states, except for explicitly allowed fields.
// Eliminates the duplicated guard pattern across sale.order, purchase.order,
// account.move, and stock.picking.
func StateGuard(table, stateCondition string, allowedFields []string, errMsg string) func(env *Environment, ids []int64, vals Values) error {
allowed := make(map[string]bool, len(allowedFields))
for _, f := range allowedFields {
allowed[f] = true
}
return func(env *Environment, ids []int64, vals Values) error {
if _, changingState := vals["state"]; changingState {
return nil
}
var count int
err := env.Tx().QueryRow(env.Ctx(),
fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE id = ANY($1) AND %s`, table, stateCondition), ids,
).Scan(&count)
if err != nil || count == 0 {
return nil
}
for field := range vals {
if !allowed[field] {
return fmt.Errorf("%s: %s", table, errMsg)
}
}
return nil
}
}

View File

@@ -153,7 +153,7 @@ func (rs *Recordset) ReadGroup(domain Domain, groupby []string, aggregates []str
// Build ORDER BY
orderSQL := ""
if opt.Order != "" {
orderSQL = opt.Order
orderSQL = sanitizeOrderBy(opt.Order, m)
} else if len(gbCols) > 0 {
// Default: order by groupby columns
var orderParts []string

View File

@@ -2,6 +2,7 @@ package orm
import (
"fmt"
"log"
"strings"
)
@@ -265,18 +266,28 @@ func preprocessRelatedWrites(env *Environment, m *Model, ids []int64, vals Value
value := vals[fieldName]
delete(vals, fieldName) // Remove from vals — no local column
// Read FK IDs for all records
// Read FK IDs for all records in a single query
var fkIDs []int64
for _, id := range ids {
var fkID *int64
env.tx.QueryRow(env.ctx,
fmt.Sprintf(`SELECT %q FROM %q WHERE id = $1`, fkDef.Column(), m.Table()),
id,
).Scan(&fkID)
if fkID != nil && *fkID > 0 {
fkIDs = append(fkIDs, *fkID)
rows, err := env.tx.Query(env.ctx,
fmt.Sprintf(`SELECT %q FROM %q WHERE id = ANY($1) AND %q IS NOT NULL`,
fkDef.Column(), m.Table(), fkDef.Column()),
ids,
)
if err != nil {
delete(vals, fieldName)
continue
}
for rows.Next() {
var fkID int64
if err := rows.Scan(&fkID); err != nil {
log.Printf("orm: preprocessRelatedWrites scan error on %s.%s: %v", m.Name(), fieldName, err)
continue
}
if fkID > 0 {
fkIDs = append(fkIDs, fkID)
}
}
rows.Close()
if len(fkIDs) == 0 {
continue
@@ -315,6 +326,13 @@ func (rs *Recordset) Write(vals Values) error {
m := rs.model
// BeforeWrite hook — state guards, locked record checks etc.
if m.BeforeWrite != nil {
if err := m.BeforeWrite(rs.env, rs.ids, vals); err != nil {
return err
}
}
var setClauses []string
var args []interface{}
idx := 1
@@ -787,7 +805,7 @@ func (rs *Recordset) Search(domain Domain, opts ...SearchOpts) (*Recordset, erro
// Build query
order := m.order
if opt.Order != "" {
order = opt.Order
order = sanitizeOrderBy(opt.Order, m)
}
joinSQL := compiler.JoinSQL()
@@ -1103,6 +1121,72 @@ func toRecordID(v interface{}) (int64, bool) {
return 0, false
}
// sanitizeOrderBy validates an ORDER BY clause to prevent SQL injection.
// Only allows: field names (alphanumeric + underscore), ASC/DESC, NULLS FIRST/LAST, commas.
// Returns sanitized string or fallback to "id" if invalid.
func sanitizeOrderBy(order string, m *Model) string {
if order == "" {
return "id"
}
parts := strings.Split(order, ",")
var safe []string
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
tokens := strings.Fields(part)
if len(tokens) == 0 {
continue
}
// First token must be a valid field name or "table"."field"
col := tokens[0]
// Strip quotes for validation
cleanCol := strings.ReplaceAll(strings.ReplaceAll(col, "\"", ""), "'", "")
// Allow dot notation (table.field) but validate each part
colParts := strings.Split(cleanCol, ".")
valid := true
for _, cp := range colParts {
if !isValidIdentifier(cp) {
valid = false
break
}
}
if !valid {
continue // Skip this part entirely
}
// Remaining tokens must be ASC, DESC, NULLS, FIRST, LAST
safePart := col
for _, tok := range tokens[1:] {
upper := strings.ToUpper(tok)
switch upper {
case "ASC", "DESC", "NULLS", "FIRST", "LAST":
safePart += " " + upper
default:
// Invalid token — skip
}
}
safe = append(safe, safePart)
}
if len(safe) == 0 {
return "id"
}
return strings.Join(safe, ", ")
}
// isValidIdentifier checks if a string is a valid SQL identifier (letters, digits, underscore).
func isValidIdentifier(s string) bool {
if s == "" {
return false
}
for _, c := range s {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
return false
}
}
return true
}
// qualifyOrderBy prefixes unqualified column names with the table name.
// "name, id desc" → "\"my_table\".name, \"my_table\".id desc"
func qualifyOrderBy(table, order string) string {

View File

@@ -70,8 +70,9 @@ func ApplyRecordRules(env *Environment, m *Model, domain Domain) Domain {
ORDER BY r.id`,
m.Name(), env.UID())
if err != nil {
log.Printf("orm: ir.rule query failed for %s: %v — denying access", m.Name(), err)
sp.Rollback(env.ctx)
return domain
return append(domain, Leaf("id", "=", -1)) // Deny all — no records match id=-1
}
type ruleRow struct {
@@ -207,7 +208,8 @@ func CheckRecordRuleAccess(env *Environment, m *Model, ids []int64, perm string)
var count int64
err := env.tx.QueryRow(env.ctx, query, args...).Scan(&count)
if err != nil {
return nil // Fail open on error
log.Printf("orm: record rule check failed for %s: %v", m.Name(), err)
return fmt.Errorf("orm: access denied on %s (record rule check failed)", m.Name())
}
if count < int64(len(ids)) {

View File

@@ -3,6 +3,7 @@ package server
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
)
@@ -145,10 +146,12 @@ func (s *Server) handleActionLoad(w http.ResponseWriter, r *http.Request) {
// Look up xml_id from ir_model_data
xmlID := ""
_ = s.pool.QueryRow(ctx,
if err := s.pool.QueryRow(ctx,
`SELECT module || '.' || name FROM ir_model_data
WHERE model = 'ir.actions.act_window' AND res_id = $1
LIMIT 1`, id).Scan(&xmlID)
LIMIT 1`, id).Scan(&xmlID); err != nil {
log.Printf("warning: action xml_id lookup failed for id=%d: %v", id, err)
}
// Build views array from view_mode string (e.g. "list,kanban,form" → [[nil,"list"],[nil,"kanban"],[nil,"form"]])
views := buildViewsFromMode(viewMode)

292
pkg/server/bank_import.go Normal file
View File

@@ -0,0 +1,292 @@
package server
import (
"encoding/csv"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"time"
"odoo-go/pkg/orm"
)
// handleBankStatementImport imports bank statement lines from CSV data.
// Accepts JSON body with: journal_id, csv_data, column_mapping, has_header.
// After import, optionally triggers auto-matching against open invoices.
// Mirrors: odoo/addons/account/wizard/account_bank_statement_import.py
func (s *Server) handleBankStatementImport(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req JSONRPCRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.writeJSONRPC(w, nil, nil, &RPCError{Code: -32700, Message: "Parse error"})
return
}
var params struct {
JournalID int64 `json:"journal_id"`
CSVData string `json:"csv_data"`
HasHeader bool `json:"has_header"`
ColumnMapping bankColumnMapping `json:"column_mapping"`
AutoMatch bool `json:"auto_match"`
}
if err := json.Unmarshal(req.Params, &params); err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "Invalid params"})
return
}
if params.JournalID == 0 || params.CSVData == "" {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "journal_id and csv_data are required"})
return
}
uid := int64(1)
companyID := int64(1)
if sess := GetSession(r); sess != nil {
uid = sess.UID
companyID = sess.CompanyID
}
env, err := orm.NewEnvironment(r.Context(), orm.EnvConfig{
Pool: s.pool,
UID: uid,
CompanyID: companyID,
})
if err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: "Internal error"})
return
}
defer env.Close()
// Parse CSV
reader := csv.NewReader(strings.NewReader(params.CSVData))
reader.LazyQuotes = true
reader.TrimLeadingSpace = true
// Try semicolon separator (common in European bank exports)
reader.Comma = detectDelimiter(params.CSVData)
var allRows [][]string
for {
row, err := reader.Read()
if err == io.EOF {
break
}
if err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: fmt.Sprintf("CSV parse error: %v", err)})
return
}
allRows = append(allRows, row)
}
dataRows := allRows
if params.HasHeader && len(allRows) > 1 {
dataRows = allRows[1:]
}
// Create a bank statement header
statementRS := env.Model("account.bank.statement")
stmt, err := statementRS.Create(orm.Values{
"name": fmt.Sprintf("Import %s", time.Now().Format("2006-01-02 15:04")),
"journal_id": params.JournalID,
"company_id": companyID,
"date": time.Now().Format("2006-01-02"),
})
if err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: fmt.Sprintf("Create statement: %v", err)})
return
}
stmtID := stmt.ID()
// Default column mapping
cm := params.ColumnMapping
if cm.Date < 0 {
cm.Date = 0
}
if cm.Amount < 0 {
cm.Amount = 1
}
if cm.Label < 0 {
cm.Label = 2
}
// Import lines
lineRS := env.Model("account.bank.statement.line")
var importedIDs []int64
var errors []importError
for rowIdx, row := range dataRows {
// Parse date
dateStr := safeCol(row, cm.Date)
date := parseFlexDate(dateStr)
if date == "" {
date = time.Now().Format("2006-01-02")
}
// Parse amount
amountStr := safeCol(row, cm.Amount)
amount := parseAmount(amountStr)
if amount == 0 {
continue // skip zero-amount rows
}
// Parse label/reference
label := safeCol(row, cm.Label)
if label == "" {
label = fmt.Sprintf("Line %d", rowIdx+1)
}
// Parse optional columns
partnerName := safeCol(row, cm.PartnerName)
accountNumber := safeCol(row, cm.AccountNumber)
vals := orm.Values{
"statement_id": stmtID,
"journal_id": params.JournalID,
"company_id": companyID,
"date": date,
"amount": amount,
"payment_ref": label,
"partner_name": partnerName,
"account_number": accountNumber,
"sequence": rowIdx + 1,
}
rec, err := lineRS.Create(vals)
if err != nil {
errors = append(errors, importError{Row: rowIdx + 1, Message: err.Error()})
log.Printf("bank_import: row %d error: %v", rowIdx+1, err)
continue
}
importedIDs = append(importedIDs, rec.ID())
}
// Auto-match against open invoices
matchCount := 0
if params.AutoMatch && len(importedIDs) > 0 {
stLineModel := orm.Registry.Get("account.bank.statement.line")
if stLineModel != nil {
if matchMethod, ok := stLineModel.Methods["button_match"]; ok {
matchRS := env.Model("account.bank.statement.line").Browse(importedIDs...)
if _, err := matchMethod(matchRS); err != nil {
log.Printf("bank_import: auto-match error: %v", err)
}
}
}
// Count how many were matched
env.Tx().QueryRow(env.Ctx(),
`SELECT COUNT(*) FROM account_bank_statement_line WHERE id = ANY($1) AND is_reconciled = true`,
importedIDs).Scan(&matchCount)
}
if err := env.Commit(); err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: fmt.Sprintf("Commit: %v", err)})
return
}
s.writeJSONRPC(w, req.ID, map[string]interface{}{
"statement_id": stmtID,
"imported": len(importedIDs),
"matched": matchCount,
"errors": errors,
}, nil)
}
// bankColumnMapping maps CSV columns to bank statement fields.
type bankColumnMapping struct {
Date int `json:"date"` // column index for date
Amount int `json:"amount"` // column index for amount
Label int `json:"label"` // column index for label/reference
PartnerName int `json:"partner_name"` // column index for partner name (-1 = skip)
AccountNumber int `json:"account_number"` // column index for account number (-1 = skip)
}
// detectDelimiter guesses the CSV delimiter (comma, semicolon, or tab).
func detectDelimiter(data string) rune {
firstLine := data
if idx := strings.IndexByte(data, '\n'); idx > 0 {
firstLine = data[:idx]
}
semicolons := strings.Count(firstLine, ";")
commas := strings.Count(firstLine, ",")
tabs := strings.Count(firstLine, "\t")
if semicolons > commas && semicolons > tabs {
return ';'
}
if tabs > commas {
return '\t'
}
return ','
}
// safeCol returns the value at index i, or "" if out of bounds.
func safeCol(row []string, i int) string {
if i < 0 || i >= len(row) {
return ""
}
return strings.TrimSpace(row[i])
}
// parseFlexDate tries multiple date formats and returns YYYY-MM-DD.
func parseFlexDate(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return ""
}
formats := []string{
"2006-01-02",
"02.01.2006", // DD.MM.YYYY (common in EU)
"01/02/2006", // MM/DD/YYYY
"02/01/2006", // DD/MM/YYYY
"2006/01/02",
"Jan 2, 2006",
"2 Jan 2006",
"02-01-2006",
"01-02-2006",
time.RFC3339,
}
for _, f := range formats {
if t, err := time.Parse(f, s); err == nil {
return t.Format("2006-01-02")
}
}
return ""
}
// parseAmount parses a monetary amount string, handling comma/dot decimals and negative formats.
func parseAmount(s string) float64 {
s = strings.TrimSpace(s)
if s == "" {
return 0
}
// Remove currency symbols and whitespace
s = strings.NewReplacer("€", "", "$", "", "£", "", " ", "", "\u00a0", "").Replace(s)
// Handle European format: 1.234,56 → 1234.56
if strings.Contains(s, ",") && strings.Contains(s, ".") {
if strings.LastIndex(s, ",") > strings.LastIndex(s, ".") {
// comma is decimal: 1.234,56
s = strings.ReplaceAll(s, ".", "")
s = strings.ReplaceAll(s, ",", ".")
} else {
// dot is decimal: 1,234.56
s = strings.ReplaceAll(s, ",", "")
}
} else if strings.Contains(s, ",") {
// Only comma: assume decimal separator
s = strings.ReplaceAll(s, ",", ".")
}
v, err := strconv.ParseFloat(s, 64)
if err != nil {
return 0
}
return v
}

241
pkg/server/bus.go Normal file
View File

@@ -0,0 +1,241 @@
package server
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"sync"
"time"
)
// Bus implements a simple long-polling message bus for Discuss.
// Mirrors: odoo/addons/bus/models/bus.py ImBus
//
// Channels subscribe to notifications. A long-poll request blocks until
// a notification arrives or the timeout expires.
type Bus struct {
mu sync.Mutex
channels map[int64][]chan busNotification
lastID int64
}
type busNotification struct {
ID int64 `json:"id"`
Channel string `json:"channel"`
Message interface{} `json:"message"`
}
// NewBus creates a new message bus.
func NewBus() *Bus {
return &Bus{
channels: make(map[int64][]chan busNotification),
}
}
// Notify sends a notification to all subscribers of a channel.
func (b *Bus) Notify(channelID int64, channel string, message interface{}) {
b.mu.Lock()
b.lastID++
notif := busNotification{
ID: b.lastID,
Channel: channel,
Message: message,
}
subs := b.channels[channelID]
b.mu.Unlock()
for _, ch := range subs {
select {
case ch <- notif:
default:
// subscriber buffer full, skip
}
}
}
// Subscribe creates a subscription for a partner's channels.
func (b *Bus) Subscribe(partnerID int64) chan busNotification {
ch := make(chan busNotification, 10)
b.mu.Lock()
b.channels[partnerID] = append(b.channels[partnerID], ch)
b.mu.Unlock()
return ch
}
// Unsubscribe removes a subscription.
func (b *Bus) Unsubscribe(partnerID int64, ch chan busNotification) {
b.mu.Lock()
defer b.mu.Unlock()
subs := b.channels[partnerID]
for i, s := range subs {
if s == ch {
b.channels[partnerID] = append(subs[:i], subs[i+1:]...)
close(ch)
return
}
}
}
// registerBusRoutes adds the long-polling endpoint.
func (s *Server) registerBusRoutes() {
if s.bus == nil {
s.bus = NewBus()
}
s.mux.HandleFunc("/longpolling/poll", s.handleBusPoll)
s.mux.HandleFunc("/discuss/channel/messages", s.handleDiscussMessages)
s.mux.HandleFunc("/discuss/channel/list", s.handleDiscussChannelList)
}
// handleBusPoll implements long-polling for real-time notifications.
// Mirrors: odoo/addons/bus/controllers/main.py poll()
func (s *Server) handleBusPoll(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
sess := GetSession(r)
if sess == nil {
writeJSON(w, []interface{}{})
return
}
// Get partner ID
var partnerID int64
s.pool.QueryRow(r.Context(),
`SELECT COALESCE(partner_id, 0) FROM res_users WHERE id = $1`, sess.UID,
).Scan(&partnerID)
if partnerID == 0 {
writeJSON(w, []interface{}{})
return
}
// Subscribe and wait for notifications (max 30s)
ch := s.bus.Subscribe(partnerID)
defer s.bus.Unsubscribe(partnerID, ch)
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
defer cancel()
select {
case notif := <-ch:
writeJSON(w, []busNotification{notif})
case <-ctx.Done():
writeJSON(w, []interface{}{}) // timeout, empty response
}
}
// handleDiscussMessages fetches messages for a channel via JSON-RPC.
func (s *Server) handleDiscussMessages(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
sess := GetSession(r)
if sess == nil {
s.writeJSONRPC(w, nil, nil, &RPCError{Code: 100, Message: "Not authenticated"})
return
}
var req JSONRPCRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.writeJSONRPC(w, nil, nil, &RPCError{Code: -32700, Message: "Parse error"})
return
}
var params struct {
ChannelID int64 `json:"channel_id"`
Limit int `json:"limit"`
}
if err := json.Unmarshal(req.Params, &params); err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "Invalid params"})
return
}
if params.Limit <= 0 {
params.Limit = 50
}
rows, err := s.pool.Query(r.Context(),
`SELECT m.id, m.body, m.date, m.author_id, COALESCE(p.name, '')
FROM mail_message m
LEFT JOIN res_partner p ON p.id = m.author_id
WHERE m.model = 'mail.channel' AND m.res_id = $1
ORDER BY m.id DESC LIMIT $2`, params.ChannelID, params.Limit)
if err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: fmt.Sprintf("Query: %v", err)})
return
}
defer rows.Close()
var messages []map[string]interface{}
for rows.Next() {
var id, authorID int64
var body, authorName string
var date interface{}
if err := rows.Scan(&id, &body, &date, &authorID, &authorName); err != nil {
continue
}
msg := map[string]interface{}{
"id": id, "body": body, "date": date,
}
if authorID > 0 {
msg["author_id"] = []interface{}{authorID, authorName}
} else {
msg["author_id"] = false
}
messages = append(messages, msg)
}
if messages == nil {
messages = []map[string]interface{}{}
}
s.writeJSONRPC(w, req.ID, messages, nil)
}
// handleDiscussChannelList returns channels the current user is member of.
func (s *Server) handleDiscussChannelList(w http.ResponseWriter, r *http.Request) {
sess := GetSession(r)
if sess == nil {
s.writeJSONRPC(w, nil, nil, &RPCError{Code: 100, Message: "Not authenticated"})
return
}
var partnerID int64
s.pool.QueryRow(r.Context(),
`SELECT COALESCE(partner_id, 0) FROM res_users WHERE id = $1`, sess.UID,
).Scan(&partnerID)
rows, err := s.pool.Query(r.Context(),
`SELECT c.id, c.name, c.channel_type,
(SELECT COUNT(*) FROM mail_channel_member WHERE channel_id = c.id) AS members
FROM mail_channel c
JOIN mail_channel_member cm ON cm.channel_id = c.id AND cm.partner_id = $1
WHERE c.active = true
ORDER BY c.last_message_date DESC NULLS LAST`, partnerID)
if err != nil {
log.Printf("discuss: channel list error: %v", err)
writeJSON(w, []interface{}{})
return
}
defer rows.Close()
var channels []map[string]interface{}
for rows.Next() {
var id int64
var name, channelType string
var members int64
if err := rows.Scan(&id, &name, &channelType, &members); err != nil {
continue
}
channels = append(channels, map[string]interface{}{
"id": id, "name": name, "channel_type": channelType, "member_count": members,
})
}
if channels == nil {
channels = []map[string]interface{}{}
}
writeJSON(w, channels)
}

View File

@@ -6,37 +6,45 @@ import (
"fmt"
"net/http"
"github.com/xuri/excelize/v2"
"odoo-go/pkg/orm"
)
// handleExportCSV exports records as CSV.
// Mirrors: odoo/addons/web/controllers/export.py ExportController
func (s *Server) handleExportCSV(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// exportField describes a field in an export request.
type exportField struct {
Name string `json:"name"`
Label string `json:"label"`
}
// exportData holds the parsed and fetched data for an export operation.
type exportData struct {
Model string
FieldNames []string
Headers []string
Records []orm.Values
}
// parseExportRequest parses the common request/params/env/search logic shared by CSV and XLSX export.
func (s *Server) parseExportRequest(w http.ResponseWriter, r *http.Request) (*exportData, error) {
var req JSONRPCRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.writeJSONRPC(w, nil, nil, &RPCError{Code: -32700, Message: "Parse error"})
return
return nil, err
}
var params struct {
Data struct {
Model string `json:"model"`
Fields []exportField `json:"fields"`
Domain []interface{} `json:"domain"`
IDs []float64 `json:"ids"`
Model string `json:"model"`
Fields []exportField `json:"fields"`
Domain []interface{} `json:"domain"`
IDs []float64 `json:"ids"`
} `json:"data"`
}
if err := json.Unmarshal(req.Params, &params); err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "Invalid params"})
return
return nil, err
}
// Extract UID from session
uid := int64(1)
companyID := int64(1)
if sess := GetSession(r); sess != nil {
@@ -45,42 +53,31 @@ func (s *Server) handleExportCSV(w http.ResponseWriter, r *http.Request) {
}
env, err := orm.NewEnvironment(r.Context(), orm.EnvConfig{
Pool: s.pool,
UID: uid,
CompanyID: companyID,
Pool: s.pool, UID: uid, CompanyID: companyID,
})
if err != nil {
http.Error(w, "Internal error", http.StatusInternalServerError)
return
return nil, err
}
defer env.Close()
rs := env.Model(params.Data.Model)
// Determine which record IDs to export
var ids []int64
if len(params.Data.IDs) > 0 {
for _, id := range params.Data.IDs {
ids = append(ids, int64(id))
}
} else {
// Search with domain
domain := parseDomain([]interface{}{params.Data.Domain})
found, err := rs.Search(domain, orm.SearchOpts{Limit: 10000})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
return nil, err
}
ids = found.IDs()
}
if len(ids) == 0 {
w.Header().Set("Content-Type", "text/csv")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.csv", params.Data.Model))
return
}
// Extract field names
var fieldNames []string
var headers []string
for _, f := range params.Data.Fields {
@@ -92,42 +89,89 @@ func (s *Server) handleExportCSV(w http.ResponseWriter, r *http.Request) {
headers = append(headers, label)
}
// Read records
records, err := rs.Browse(ids...).Read(fieldNames)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
var records []orm.Values
if len(ids) > 0 {
records, err = rs.Browse(ids...).Read(fieldNames)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return nil, err
}
}
if err := env.Commit(); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return nil, err
}
return &exportData{
Model: params.Data.Model, FieldNames: fieldNames, Headers: headers, Records: records,
}, nil
}
// handleExportCSV exports records as CSV.
// Mirrors: odoo/addons/web/controllers/export.py ExportController
func (s *Server) handleExportCSV(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
data, err := s.parseExportRequest(w, r)
if err != nil {
return
}
// Write CSV
w.Header().Set("Content-Type", "text/csv; charset=utf-8")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.csv", params.Data.Model))
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.csv", data.Model))
writer := csv.NewWriter(w)
defer writer.Flush()
// Header row
writer.Write(headers)
// Data rows
for _, rec := range records {
row := make([]string, len(fieldNames))
for i, fname := range fieldNames {
writer.Write(data.Headers)
for _, rec := range data.Records {
row := make([]string, len(data.FieldNames))
for i, fname := range data.FieldNames {
row[i] = formatCSVValue(rec[fname])
}
writer.Write(row)
}
}
// exportField describes a field in an export request.
type exportField struct {
Name string `json:"name"`
Label string `json:"label"`
// handleExportXLSX exports records as XLSX (Excel).
// Mirrors: odoo/addons/web/controllers/export.py ExportXlsxController
func (s *Server) handleExportXLSX(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
data, err := s.parseExportRequest(w, r)
if err != nil {
return
}
f := excelize.NewFile()
sheet := "Sheet1"
headerStyle, _ := f.NewStyle(&excelize.Style{
Font: &excelize.Font{Bold: true},
})
for i, h := range data.Headers {
cell, _ := excelize.CoordinatesToCellName(i+1, 1)
f.SetCellValue(sheet, cell, h)
f.SetCellStyle(sheet, cell, cell, headerStyle)
}
for rowIdx, rec := range data.Records {
for colIdx, fname := range data.FieldNames {
cell, _ := excelize.CoordinatesToCellName(colIdx+1, rowIdx+2)
f.SetCellValue(sheet, cell, formatCSVValue(rec[fname]))
}
}
w.Header().Set("Content-Type", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.xlsx", data.Model))
f.Write(w)
}
// formatCSVValue converts a field value to a CSV string.

View File

@@ -2,6 +2,7 @@ package server
import (
"fmt"
"log"
"net/http"
"strconv"
"strings"
@@ -55,9 +56,11 @@ func (s *Server) handleImage(w http.ResponseWriter, r *http.Request) {
table := m.Table()
var data []byte
ctx := r.Context()
_ = s.pool.QueryRow(ctx,
if err := s.pool.QueryRow(ctx,
fmt.Sprintf(`SELECT "%s" FROM "%s" WHERE id = $1`, f.Column(), table), id,
).Scan(&data)
).Scan(&data); err != nil {
log.Printf("warning: image query failed for %s.%s id=%d: %v", model, field, id, err)
}
if len(data) > 0 {
// Detect content type
contentType := http.DetectContentType(data)
@@ -76,9 +79,11 @@ func (s *Server) handleImage(w http.ResponseWriter, r *http.Request) {
m := orm.Registry.Get(model)
if m != nil {
var name string
_ = s.pool.QueryRow(r.Context(),
if err := s.pool.QueryRow(r.Context(),
fmt.Sprintf(`SELECT COALESCE(name, '') FROM "%s" WHERE id = $1`, m.Table()), id,
).Scan(&name)
).Scan(&name); err != nil {
log.Printf("warning: image name lookup failed for %s id=%d: %v", model, id, err)
}
if len(name) > 0 {
initial = strings.ToUpper(name[:1])
}

223
pkg/server/import.go Normal file
View File

@@ -0,0 +1,223 @@
package server
import (
"encoding/csv"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"odoo-go/pkg/orm"
)
// handleImportCSV imports records from a CSV file into any model.
// Accepts JSON body with: model, fields (mapping), csv_data (raw CSV string).
// Mirrors: odoo/addons/base_import/controllers/main.py ImportController
func (s *Server) handleImportCSV(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req JSONRPCRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.writeJSONRPC(w, nil, nil, &RPCError{Code: -32700, Message: "Parse error"})
return
}
var params struct {
Model string `json:"model"`
Fields []importFieldMap `json:"fields"`
CSVData string `json:"csv_data"`
HasHeader bool `json:"has_header"`
DryRun bool `json:"dry_run"`
}
if err := json.Unmarshal(req.Params, &params); err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "Invalid params"})
return
}
if params.Model == "" || len(params.Fields) == 0 || params.CSVData == "" {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "model, fields, and csv_data are required"})
return
}
// Verify model exists
m := orm.Registry.Get(params.Model)
if m == nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: fmt.Sprintf("Unknown model: %s", params.Model)})
return
}
// Parse CSV
reader := csv.NewReader(strings.NewReader(params.CSVData))
reader.LazyQuotes = true
reader.TrimLeadingSpace = true
var allRows [][]string
for {
row, err := reader.Read()
if err == io.EOF {
break
}
if err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: fmt.Sprintf("CSV parse error: %v", err)})
return
}
allRows = append(allRows, row)
}
if len(allRows) == 0 {
s.writeJSONRPC(w, req.ID, map[string]interface{}{"ids": []int64{}, "count": 0}, nil)
return
}
// Skip header row if present
dataRows := allRows
if params.HasHeader && len(allRows) > 1 {
dataRows = allRows[1:]
}
// Build field mapping: CSV column index → ORM field name
type colMapping struct {
colIndex int
fieldName string
fieldType orm.FieldType
}
var mappings []colMapping
for _, fm := range params.Fields {
if fm.FieldName == "" || fm.ColumnIndex < 0 {
continue
}
f := m.GetField(fm.FieldName)
if f == nil {
continue // skip unknown fields
}
mappings = append(mappings, colMapping{
colIndex: fm.ColumnIndex,
fieldName: fm.FieldName,
fieldType: f.Type,
})
}
if len(mappings) == 0 {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "No valid field mappings"})
return
}
uid := int64(1)
companyID := int64(1)
if sess := GetSession(r); sess != nil {
uid = sess.UID
companyID = sess.CompanyID
}
env, err := orm.NewEnvironment(r.Context(), orm.EnvConfig{
Pool: s.pool,
UID: uid,
CompanyID: companyID,
})
if err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: "Internal error"})
return
}
defer env.Close()
rs := env.Model(params.Model)
var createdIDs []int64
var errors []importError
for rowIdx, row := range dataRows {
vals := make(orm.Values)
for _, cm := range mappings {
if cm.colIndex >= len(row) {
continue
}
raw := strings.TrimSpace(row[cm.colIndex])
if raw == "" {
continue
}
vals[cm.fieldName] = coerceImportValue(raw, cm.fieldType)
}
if len(vals) == 0 {
continue
}
if params.DryRun {
continue // validate only, don't create
}
rec, err := rs.Create(vals)
if err != nil {
errors = append(errors, importError{
Row: rowIdx + 1,
Message: err.Error(),
})
log.Printf("import: row %d error: %v", rowIdx+1, err)
continue
}
createdIDs = append(createdIDs, rec.ID())
}
if err := env.Commit(); err != nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32603, Message: fmt.Sprintf("Commit error: %v", err)})
return
}
result := map[string]interface{}{
"ids": createdIDs,
"count": len(createdIDs),
"errors": errors,
"dry_run": params.DryRun,
}
s.writeJSONRPC(w, req.ID, result, nil)
}
// importFieldMap maps a CSV column to an ORM field.
type importFieldMap struct {
ColumnIndex int `json:"column_index"`
FieldName string `json:"field_name"`
}
// importError describes a per-row import error.
type importError struct {
Row int `json:"row"`
Message string `json:"message"`
}
// coerceImportValue converts a raw CSV string to the appropriate Go type for ORM Create.
func coerceImportValue(raw string, ft orm.FieldType) interface{} {
switch ft {
case orm.TypeInteger:
v, err := strconv.ParseInt(raw, 10, 64)
if err != nil {
return nil
}
return v
case orm.TypeFloat, orm.TypeMonetary:
// Handle comma as decimal separator
raw = strings.ReplaceAll(raw, ",", ".")
v, err := strconv.ParseFloat(raw, 64)
if err != nil {
return nil
}
return v
case orm.TypeBoolean:
lower := strings.ToLower(raw)
return lower == "true" || lower == "1" || lower == "yes" || lower == "ja"
case orm.TypeMany2one:
// Try as integer ID first, then as name_search later
v, err := strconv.ParseInt(raw, 10, 64)
if err != nil {
return raw // pass as string, ORM may handle name_create
}
return v
default:
return raw
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"log"
"net/http"
"path/filepath"
"strings"
"time"
)
@@ -43,13 +44,19 @@ func (w *statusWriter) WriteHeader(code int) {
func AuthMiddleware(store *SessionStore, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Public endpoints (no auth required)
path := r.URL.Path
path := filepath.Clean(r.URL.Path)
if path == "/health" ||
path == "/web/login" ||
path == "/web/session/authenticate" ||
path == "/web/session/logout" ||
strings.HasPrefix(path, "/web/database/") ||
path == "/web/database/manager" ||
path == "/web/database/create" ||
path == "/web/database/list" ||
path == "/web/webclient/version_info" ||
path == "/web/setup/wizard" ||
path == "/web/setup/wizard/save" ||
path == "/web/portal/signup" ||
path == "/web/portal/reset_password" ||
strings.Contains(path, "/static/") {
next.ServeHTTP(w, r)
return
@@ -58,8 +65,14 @@ func AuthMiddleware(store *SessionStore, next http.Handler) http.Handler {
// Check session cookie
cookie, err := r.Cookie("session_id")
if err != nil || cookie.Value == "" {
// Also check JSON-RPC params for session_id (Odoo sends it both ways)
next.ServeHTTP(w, r) // For now, allow through — UID defaults to 1
// No session cookie — reject protected endpoints
if r.Header.Get("Content-Type") == "application/json" ||
strings.HasPrefix(path, "/web/dataset/") ||
strings.HasPrefix(path, "/jsonrpc") {
http.Error(w, `{"jsonrpc":"2.0","error":{"code":100,"message":"Session expired"}}`, http.StatusUnauthorized)
} else {
http.Redirect(w, r, "/web/login", http.StatusFound)
}
return
}

379
pkg/server/portal.go Normal file
View File

@@ -0,0 +1,379 @@
// Package server — Portal controllers for external (customer/supplier) access.
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal
package server
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
)
// registerPortalRoutes registers all /my/* portal endpoints.
func (s *Server) registerPortalRoutes() {
s.mux.HandleFunc("/my", s.handlePortalHome)
s.mux.HandleFunc("/my/", s.handlePortalDispatch)
s.mux.HandleFunc("/my/home", s.handlePortalHome)
s.mux.HandleFunc("/my/invoices", s.handlePortalInvoices)
s.mux.HandleFunc("/my/orders", s.handlePortalOrders)
s.mux.HandleFunc("/my/pickings", s.handlePortalPickings)
s.mux.HandleFunc("/my/account", s.handlePortalAccount)
}
// handlePortalDispatch routes /my/* sub-paths to the correct handler.
func (s *Server) handlePortalDispatch(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/my/home":
s.handlePortalHome(w, r)
case "/my/invoices":
s.handlePortalInvoices(w, r)
case "/my/orders":
s.handlePortalOrders(w, r)
case "/my/pickings":
s.handlePortalPickings(w, r)
case "/my/account":
s.handlePortalAccount(w, r)
default:
s.handlePortalHome(w, r)
}
}
// portalPartnerID resolves the partner_id of the currently logged-in portal user.
// Returns (partnerID, error). If session is missing, writes an error response and returns 0.
func (s *Server) portalPartnerID(w http.ResponseWriter, r *http.Request) (int64, bool) {
sess := GetSession(r)
if sess == nil {
writePortalError(w, http.StatusUnauthorized, "Not authenticated")
return 0, false
}
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel()
var partnerID int64
err := s.pool.QueryRow(ctx,
`SELECT partner_id FROM res_users WHERE id = $1 AND active = true`,
sess.UID).Scan(&partnerID)
if err != nil {
log.Printf("portal: cannot resolve partner_id for uid=%d: %v", sess.UID, err)
writePortalError(w, http.StatusForbidden, "User not found")
return 0, false
}
return partnerID, true
}
// handlePortalHome returns the portal dashboard with document counts.
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal.home()
func (s *Server) handlePortalHome(w http.ResponseWriter, r *http.Request) {
partnerID, ok := s.portalPartnerID(w, r)
if !ok {
return
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
var invoiceCount, orderCount, pickingCount int64
// Count invoices (account.move with move_type in ('out_invoice','out_refund'))
err := s.pool.QueryRow(ctx,
`SELECT COUNT(*) FROM account_move
WHERE partner_id = $1 AND move_type IN ('out_invoice','out_refund')
AND state = 'posted'`, partnerID).Scan(&invoiceCount)
if err != nil {
log.Printf("portal: invoice count error: %v", err)
}
// Count sale orders (confirmed or done)
err = s.pool.QueryRow(ctx,
`SELECT COUNT(*) FROM sale_order
WHERE partner_id = $1 AND state IN ('sale','done')`, partnerID).Scan(&orderCount)
if err != nil {
log.Printf("portal: order count error: %v", err)
}
// Count pickings (stock.picking)
err = s.pool.QueryRow(ctx,
`SELECT COUNT(*) FROM stock_picking
WHERE partner_id = $1 AND state != 'cancel'`, partnerID).Scan(&pickingCount)
if err != nil {
log.Printf("portal: picking count error: %v", err)
}
writePortalJSON(w, map[string]interface{}{
"counters": map[string]int64{
"invoice_count": invoiceCount,
"order_count": orderCount,
"picking_count": pickingCount,
},
})
}
// handlePortalInvoices lists invoices for the current portal user.
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal.portal_my_invoices()
func (s *Server) handlePortalInvoices(w http.ResponseWriter, r *http.Request) {
partnerID, ok := s.portalPartnerID(w, r)
if !ok {
return
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
rows, err := s.pool.Query(ctx,
`SELECT m.id, m.name, m.move_type, m.state, m.date,
m.amount_total::float8, m.amount_residual::float8,
m.payment_state, COALESCE(m.ref, '')
FROM account_move m
WHERE m.partner_id = $1
AND m.move_type IN ('out_invoice','out_refund')
AND m.state = 'posted'
ORDER BY m.date DESC
LIMIT 80`, partnerID)
if err != nil {
log.Printf("portal: invoice query error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Failed to load invoices")
return
}
defer rows.Close()
var invoices []map[string]interface{}
for rows.Next() {
var id int64
var name, moveType, state, paymentState, ref string
var date time.Time
var amountTotal, amountResidual float64
if err := rows.Scan(&id, &name, &moveType, &state, &date,
&amountTotal, &amountResidual, &paymentState, &ref); err != nil {
log.Printf("portal: invoice scan error: %v", err)
continue
}
invoices = append(invoices, map[string]interface{}{
"id": id,
"name": name,
"move_type": moveType,
"state": state,
"date": date.Format("2006-01-02"),
"amount_total": amountTotal,
"amount_residual": amountResidual,
"payment_state": paymentState,
"ref": ref,
})
}
if invoices == nil {
invoices = []map[string]interface{}{}
}
writePortalJSON(w, map[string]interface{}{"invoices": invoices})
}
// handlePortalOrders lists sale orders for the current portal user.
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal.portal_my_orders()
func (s *Server) handlePortalOrders(w http.ResponseWriter, r *http.Request) {
partnerID, ok := s.portalPartnerID(w, r)
if !ok {
return
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
rows, err := s.pool.Query(ctx,
`SELECT so.id, so.name, so.state, so.date_order,
so.amount_total::float8, COALESCE(so.invoice_status, ''),
COALESCE(so.delivery_status, '')
FROM sale_order so
WHERE so.partner_id = $1
AND so.state IN ('sale','done')
ORDER BY so.date_order DESC
LIMIT 80`, partnerID)
if err != nil {
log.Printf("portal: order query error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Failed to load orders")
return
}
defer rows.Close()
var orders []map[string]interface{}
for rows.Next() {
var id int64
var name, state, invoiceStatus, deliveryStatus string
var dateOrder time.Time
var amountTotal float64
if err := rows.Scan(&id, &name, &state, &dateOrder,
&amountTotal, &invoiceStatus, &deliveryStatus); err != nil {
log.Printf("portal: order scan error: %v", err)
continue
}
orders = append(orders, map[string]interface{}{
"id": id,
"name": name,
"state": state,
"date_order": dateOrder.Format("2006-01-02 15:04:05"),
"amount_total": amountTotal,
"invoice_status": invoiceStatus,
"delivery_status": deliveryStatus,
})
}
if orders == nil {
orders = []map[string]interface{}{}
}
writePortalJSON(w, map[string]interface{}{"orders": orders})
}
// handlePortalPickings lists stock pickings for the current portal user.
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal.portal_my_pickings()
func (s *Server) handlePortalPickings(w http.ResponseWriter, r *http.Request) {
partnerID, ok := s.portalPartnerID(w, r)
if !ok {
return
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
rows, err := s.pool.Query(ctx,
`SELECT sp.id, sp.name, sp.state, sp.scheduled_date,
COALESCE(sp.origin, ''),
COALESCE(spt.name, '') AS picking_type_name
FROM stock_picking sp
LEFT JOIN stock_picking_type spt ON spt.id = sp.picking_type_id
WHERE sp.partner_id = $1
AND sp.state != 'cancel'
ORDER BY sp.scheduled_date DESC
LIMIT 80`, partnerID)
if err != nil {
log.Printf("portal: picking query error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Failed to load pickings")
return
}
defer rows.Close()
var pickings []map[string]interface{}
for rows.Next() {
var id int64
var name, state, origin, pickingTypeName string
var scheduledDate time.Time
if err := rows.Scan(&id, &name, &state, &scheduledDate,
&origin, &pickingTypeName); err != nil {
log.Printf("portal: picking scan error: %v", err)
continue
}
pickings = append(pickings, map[string]interface{}{
"id": id,
"name": name,
"state": state,
"scheduled_date": scheduledDate.Format("2006-01-02 15:04:05"),
"origin": origin,
"picking_type_name": pickingTypeName,
})
}
if pickings == nil {
pickings = []map[string]interface{}{}
}
writePortalJSON(w, map[string]interface{}{"pickings": pickings})
}
// handlePortalAccount returns/updates the portal user's profile.
// GET: returns user profile. POST: updates name/email/phone/street/city/zip.
// Mirrors: odoo/addons/portal/controllers/portal.py CustomerPortal.account()
func (s *Server) handlePortalAccount(w http.ResponseWriter, r *http.Request) {
partnerID, ok := s.portalPartnerID(w, r)
if !ok {
return
}
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel()
if r.Method == http.MethodPost {
// Update profile
var body struct {
Name *string `json:"name"`
Email *string `json:"email"`
Phone *string `json:"phone"`
Street *string `json:"street"`
City *string `json:"city"`
Zip *string `json:"zip"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writePortalError(w, http.StatusBadRequest, "Invalid JSON")
return
}
// Build SET clause dynamically with parameterized placeholders
sets := make([]string, 0, 6)
args := make([]interface{}, 0, 7)
idx := 1
addField := func(col string, val *string) {
if val != nil {
sets = append(sets, fmt.Sprintf("%s = $%d", col, idx))
args = append(args, *val)
idx++
}
}
addField("name", body.Name)
addField("email", body.Email)
addField("phone", body.Phone)
addField("street", body.Street)
addField("city", body.City)
addField("zip", body.Zip)
if len(sets) > 0 {
args = append(args, partnerID)
query := "UPDATE res_partner SET "
for j, set := range sets {
if j > 0 {
query += ", "
}
query += set
}
query += fmt.Sprintf(" WHERE id = $%d", idx)
if _, err := s.pool.Exec(ctx, query, args...); err != nil {
log.Printf("portal: account update error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Update failed")
return
}
}
writePortalJSON(w, map[string]interface{}{"success": true})
return
}
// GET — return profile
var name, email, phone, street, city, zip string
err := s.pool.QueryRow(ctx,
`SELECT COALESCE(name,''), COALESCE(email,''), COALESCE(phone,''),
COALESCE(street,''), COALESCE(city,''), COALESCE(zip,'')
FROM res_partner WHERE id = $1`, partnerID).Scan(
&name, &email, &phone, &street, &city, &zip)
if err != nil {
log.Printf("portal: account read error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Failed to load profile")
return
}
writePortalJSON(w, map[string]interface{}{
"name": name,
"email": email,
"phone": phone,
"street": street,
"city": city,
"zip": zip,
})
}
// --- Helpers ---
func writePortalJSON(w http.ResponseWriter, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
json.NewEncoder(w).Encode(data)
}
func writePortalError(w http.ResponseWriter, status int, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(map[string]string{"error": message})
}

313
pkg/server/portal_signup.go Normal file
View File

@@ -0,0 +1,313 @@
// Package server — Portal signup and password reset.
// Mirrors: odoo/addons/auth_signup/controllers/main.py AuthSignupHome
package server
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"time"
"odoo-go/pkg/tools"
)
// registerPortalSignupRoutes registers /web/portal/* public endpoints.
func (s *Server) registerPortalSignupRoutes() {
s.mux.HandleFunc("/web/portal/signup", s.handlePortalSignup)
s.mux.HandleFunc("/web/portal/reset_password", s.handlePortalResetPassword)
}
// handlePortalSignup creates a new portal user with share=true and a matching res.partner.
// Mirrors: odoo/addons/auth_signup/controllers/main.py AuthSignupHome.web_auth_signup()
func (s *Server) handlePortalSignup(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writePortalError(w, http.StatusMethodNotAllowed, "POST required")
return
}
var body struct {
Name string `json:"name"`
Email string `json:"email"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writePortalError(w, http.StatusBadRequest, "Invalid JSON")
return
}
// Validate required fields
body.Name = strings.TrimSpace(body.Name)
body.Email = strings.TrimSpace(body.Email)
if body.Name == "" || body.Email == "" || body.Password == "" {
writePortalError(w, http.StatusBadRequest, "Name, email, and password are required")
return
}
if len(body.Password) < 8 {
writePortalError(w, http.StatusBadRequest, "Password must be at least 8 characters")
return
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
// Check if login already exists
var exists bool
err := s.pool.QueryRow(ctx,
`SELECT EXISTS(SELECT 1 FROM res_users WHERE login = $1)`, body.Email).Scan(&exists)
if err != nil {
log.Printf("portal signup: check existing user error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Internal error")
return
}
if exists {
writePortalError(w, http.StatusConflict, "An account with this email already exists")
return
}
// Hash password
hashedPw, err := tools.HashPassword(body.Password)
if err != nil {
log.Printf("portal signup: hash password error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Internal error")
return
}
// Get default company
var companyID int64
err = s.pool.QueryRow(ctx,
`SELECT id FROM res_company WHERE active = true ORDER BY id LIMIT 1`).Scan(&companyID)
if err != nil {
log.Printf("portal signup: get company error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Internal error")
return
}
// Begin transaction — create partner + user atomically
tx, err := s.pool.Begin(ctx)
if err != nil {
log.Printf("portal signup: begin tx error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Internal error")
return
}
defer tx.Rollback(ctx)
// Create res.partner
var partnerID int64
err = tx.QueryRow(ctx,
`INSERT INTO res_partner (name, email, active, company_id, customer_rank)
VALUES ($1, $2, true, $3, 1)
RETURNING id`, body.Name, body.Email, companyID).Scan(&partnerID)
if err != nil {
log.Printf("portal signup: create partner error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Failed to create account")
return
}
// Create res.users with share=true
var userID int64
err = tx.QueryRow(ctx,
`INSERT INTO res_users (login, password, active, partner_id, company_id, share)
VALUES ($1, $2, true, $3, $4, true)
RETURNING id`, body.Email, hashedPw, partnerID, companyID).Scan(&userID)
if err != nil {
log.Printf("portal signup: create user error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Failed to create account")
return
}
// Add user to group_portal (not group_user)
var groupPortalID int64
err = tx.QueryRow(ctx,
`SELECT g.id FROM res_groups g
JOIN ir_model_data imd ON imd.res_id = g.id AND imd.model = 'res.groups'
WHERE imd.module = 'base' AND imd.name = 'group_portal'`).Scan(&groupPortalID)
if err != nil {
// group_portal might not exist yet — create it
err = tx.QueryRow(ctx,
`INSERT INTO res_groups (name) VALUES ('Portal') RETURNING id`).Scan(&groupPortalID)
if err != nil {
log.Printf("portal signup: create group_portal error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Failed to create account")
return
}
_, err = tx.Exec(ctx,
`INSERT INTO ir_model_data (module, name, model, res_id)
VALUES ('base', 'group_portal', 'res.groups', $1)
ON CONFLICT DO NOTHING`, groupPortalID)
if err != nil {
log.Printf("portal signup: create group_portal xmlid error: %v", err)
}
}
_, err = tx.Exec(ctx,
`INSERT INTO res_groups_res_users_rel (res_groups_id, res_users_id)
VALUES ($1, $2) ON CONFLICT DO NOTHING`, groupPortalID, userID)
if err != nil {
log.Printf("portal signup: add user to group_portal error: %v", err)
}
if err := tx.Commit(ctx); err != nil {
log.Printf("portal signup: commit error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Failed to create account")
return
}
log.Printf("portal signup: created portal user id=%d login=%s partner_id=%d",
userID, body.Email, partnerID)
writePortalJSON(w, map[string]interface{}{
"success": true,
"user_id": userID,
"partner_id": partnerID,
"message": "Account created successfully",
})
}
// handlePortalResetPassword handles password reset requests.
// POST with {"email":"..."}: generates a reset token and sends an email.
// POST with {"token":"...","password":"..."}: resets the password.
// Mirrors: odoo/addons/auth_signup/controllers/main.py AuthSignupHome.web_auth_reset_password()
func (s *Server) handlePortalResetPassword(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writePortalError(w, http.StatusMethodNotAllowed, "POST required")
return
}
var body struct {
Email string `json:"email"`
Token string `json:"token"`
Password string `json:"password"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writePortalError(w, http.StatusBadRequest, "Invalid JSON")
return
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
// Phase 2: Token + new password → reset
if body.Token != "" && body.Password != "" {
s.handleResetWithToken(w, ctx, body.Token, body.Password)
return
}
// Phase 1: Email → generate token + send email
if body.Email == "" {
writePortalError(w, http.StatusBadRequest, "Email is required")
return
}
s.handleResetRequest(w, ctx, strings.TrimSpace(body.Email))
}
// handleResetRequest generates a reset token and sends it via email.
func (s *Server) handleResetRequest(w http.ResponseWriter, ctx context.Context, email string) {
// Look up user
var uid int64
err := s.pool.QueryRow(ctx,
`SELECT id FROM res_users WHERE login = $1 AND active = true`, email).Scan(&uid)
if err != nil {
// Don't reveal whether the email exists — always return success
writePortalJSON(w, map[string]interface{}{
"success": true,
"message": "If an account exists with this email, a reset link has been sent",
})
return
}
// Generate token
tokenBytes := make([]byte, 32)
rand.Read(tokenBytes)
token := hex.EncodeToString(tokenBytes)
expiration := time.Now().Add(24 * time.Hour)
// Store token
_, err = s.pool.Exec(ctx,
`UPDATE res_users SET signup_token = $1, signup_expiration = $2 WHERE id = $3`,
token, expiration, uid)
if err != nil {
log.Printf("portal reset: store token error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Internal error")
return
}
// Send email with reset link
smtpCfg := tools.LoadSMTPConfig()
resetURL := fmt.Sprintf("/web/portal/reset_password?token=%s", token)
emailBody := fmt.Sprintf(`<html><body>
<p>A password reset was requested for your account.</p>
<p>Click the link below to set a new password:</p>
<p><a href="%s">Reset Password</a></p>
<p>This link expires in 24 hours.</p>
<p>If you did not request this, you can ignore this email.</p>
</body></html>`, resetURL)
if err := tools.SendEmail(smtpCfg, email, "Password Reset", emailBody); err != nil {
log.Printf("portal reset: send email error: %v", err)
// Don't expose email sending errors to the user
}
writePortalJSON(w, map[string]interface{}{
"success": true,
"message": "If an account exists with this email, a reset link has been sent",
})
}
// handleResetWithToken validates the token and sets the new password.
func (s *Server) handleResetWithToken(w http.ResponseWriter, ctx context.Context, token, password string) {
if len(password) < 8 {
writePortalError(w, http.StatusBadRequest, "Password must be at least 8 characters")
return
}
// Look up user by token
var uid int64
var expiration time.Time
err := s.pool.QueryRow(ctx,
`SELECT id, signup_expiration FROM res_users
WHERE signup_token = $1 AND active = true`, token).Scan(&uid, &expiration)
if err != nil {
writePortalError(w, http.StatusBadRequest, "Invalid or expired reset token")
return
}
// Check expiration
if time.Now().After(expiration) {
// Clear expired token
s.pool.Exec(ctx,
`UPDATE res_users SET signup_token = NULL, signup_expiration = NULL WHERE id = $1`, uid)
writePortalError(w, http.StatusBadRequest, "Reset token has expired")
return
}
// Hash new password
hashedPw, err := tools.HashPassword(password)
if err != nil {
log.Printf("portal reset: hash password error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Internal error")
return
}
// Update password and clear token
_, err = s.pool.Exec(ctx,
`UPDATE res_users SET password = $1, signup_token = NULL, signup_expiration = NULL
WHERE id = $2`, hashedPw, uid)
if err != nil {
log.Printf("portal reset: update password error: %v", err)
writePortalError(w, http.StatusInternalServerError, "Failed to reset password")
return
}
log.Printf("portal reset: password reset for uid=%d", uid)
writePortalJSON(w, map[string]interface{}{
"success": true,
"message": "Password has been reset successfully",
})
}

View File

@@ -9,6 +9,7 @@ import (
"log"
"net/http"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/pgxpool"
@@ -35,6 +36,8 @@ type Server struct {
// all JS files (except module_loader.js) plus the XML template bundle,
// served as a single file to avoid hundreds of individual HTTP requests.
jsBundle string
bus *Bus // Message bus for Discuss long-polling
}
// New creates a new server instance.
@@ -128,6 +131,17 @@ func (s *Server) registerRoutes() {
// CSV export
s.mux.HandleFunc("/web/export/csv", s.handleExportCSV)
s.mux.HandleFunc("/web/export/xlsx", s.handleExportXLSX)
// Import
s.mux.HandleFunc("/web/import/csv", s.handleImportCSV)
// Post-setup wizard
s.mux.HandleFunc("/web/setup/wizard", s.handleSetupWizard)
s.mux.HandleFunc("/web/setup/wizard/save", s.handleSetupWizardSave)
// Bank statement import
s.mux.HandleFunc("/web/bank_statement/import", s.handleBankStatementImport)
// Reports (HTML and PDF report rendering)
s.mux.HandleFunc("/report/", s.handleReport)
@@ -137,10 +151,16 @@ func (s *Server) registerRoutes() {
// Logout & Account
s.mux.HandleFunc("/web/session/logout", s.handleLogout)
s.mux.HandleFunc("/web/session/account", s.handleSessionAccount)
s.mux.HandleFunc("/web/session/switch_company", s.handleSwitchCompany)
// Health check
s.mux.HandleFunc("/health", s.handleHealth)
// Portal routes (external user access)
s.registerPortalRoutes()
s.registerPortalSignupRoutes()
s.registerBusRoutes()
// Static files (catch-all for /<addon>/static/...)
// NOTE: must be last since it's a broad pattern
}
@@ -255,13 +275,14 @@ func (s *Server) handleCallKW(w http.ResponseWriter, r *http.Request) {
return
}
// Extract UID from session, default to 1 (admin) if no session
uid := int64(1)
companyID := int64(1)
if sess := GetSession(r); sess != nil {
uid = sess.UID
companyID = sess.CompanyID
// Extract UID from session — reject if no session (defense in depth)
sess := GetSession(r)
if sess == nil {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: 100, Message: "Session expired"})
return
}
uid := sess.UID
companyID := sess.CompanyID
// Create environment for this request
env, err := orm.NewEnvironment(r.Context(), orm.EnvConfig{
@@ -294,6 +315,36 @@ func (s *Server) handleCallKW(w http.ResponseWriter, r *http.Request) {
s.writeJSONRPC(w, req.ID, result, nil)
}
// sensitiveFields lists fields that only admin (uid=1) may write to.
// Prevents privilege escalation via field manipulation.
var sensitiveFields = map[string]map[string]bool{
"ir.cron": {"user_id": true, "model_name": true, "method_name": true},
"ir.model.access": {"group_id": true, "perm_read": true, "perm_write": true, "perm_create": true, "perm_unlink": true},
"ir.rule": {"domain_force": true, "groups": true, "perm_read": true, "perm_write": true, "perm_create": true, "perm_unlink": true},
"res.users": {"groups_id": true},
"res.groups": {"users": true},
}
// checkSensitiveFields blocks non-admin users from writing protected fields.
func checkSensitiveFields(env *orm.Environment, model string, vals orm.Values) *RPCError {
if env.UID() == 1 || env.IsSuperuser() {
return nil
}
fields, ok := sensitiveFields[model]
if !ok {
return nil
}
for field := range vals {
if fields[field] {
return &RPCError{
Code: 403,
Message: fmt.Sprintf("Access Denied: field %q on %s is admin-only", field, model),
}
}
}
return nil
}
// checkAccess verifies the current user has permission for the operation.
// Mirrors: odoo/addons/base/models/ir_model.py IrModelAccess.check()
func (s *Server) checkAccess(env *orm.Environment, model, method string) *RPCError {
@@ -317,8 +368,22 @@ func (s *Server) checkAccess(env *orm.Environment, model, method string) *RPCErr
`SELECT COUNT(*) FROM ir_model_access a
JOIN ir_model m ON m.id = a.model_id
WHERE m.model = $1`, model).Scan(&count)
if err != nil || count == 0 {
return nil // No ACLs definedopen access (like Odoo superuser mode)
if err != nil {
// DB errordeny access (fail-closed)
log.Printf("access: DB error checking ACL for model %s: %v", model, err)
return &RPCError{
Code: 403,
Message: fmt.Sprintf("Access Denied: %s on %s (internal error)", method, model),
}
}
if count == 0 {
// No ACL rules defined for this model → deny (fail-closed).
// All models should have ACL seed data via seedACLRules().
log.Printf("access: no ACL for model %s, denying (fail-closed)", model)
return &RPCError{
Code: 403,
Message: fmt.Sprintf("Access Denied: no ACL rules for %s", model),
}
}
// Check if user's groups grant permission
@@ -334,7 +399,11 @@ func (s *Server) checkAccess(env *orm.Environment, model, method string) *RPCErr
AND (a.group_id IS NULL OR gu.res_users_id = $2)
)`, perm), model, env.UID()).Scan(&granted)
if err != nil {
return nil // On error, allow (fail-open for now)
log.Printf("access: DB error checking ACL grant for model %s: %v", model, err)
return &RPCError{
Code: 403,
Message: fmt.Sprintf("Access Denied: %s on %s (internal error)", method, model),
}
}
if !granted {
return &RPCError{
@@ -379,10 +448,57 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
switch params.Method {
case "has_group":
// Always return true for admin user, stub for now
return true, nil
// Check if current user belongs to the given group.
// Mirrors: odoo/orm/models.py BaseModel.user_has_groups()
groupXMLID := ""
if len(params.Args) > 0 {
groupXMLID, _ = params.Args[0].(string)
}
if groupXMLID == "" {
return false, nil
}
// Admin always has all groups
if env.UID() == 1 {
return true, nil
}
// Parse "module.xml_id" format
parts := strings.SplitN(groupXMLID, ".", 2)
if len(parts) != 2 {
return false, nil
}
// Query: does user belong to this group?
var exists bool
err := env.Tx().QueryRow(env.Ctx(),
`SELECT EXISTS(
SELECT 1 FROM res_groups_res_users_rel gur
JOIN ir_model_data imd ON imd.res_id = gur.res_groups_id AND imd.model = 'res.groups'
WHERE gur.res_users_id = $1 AND imd.module = $2 AND imd.name = $3
)`, env.UID(), parts[0], parts[1]).Scan(&exists)
if err != nil {
return false, nil
}
return exists, nil
case "check_access_rights":
// Check if current user has the given access right on this model.
// Mirrors: odoo/orm/models.py BaseModel.check_access_rights()
operation := "read"
if len(params.Args) > 0 {
if op, ok := params.Args[0].(string); ok {
operation = op
}
}
raiseException := true
if v, ok := params.KW["raise_exception"].(bool); ok {
raiseException = v
}
accessErr := s.checkAccess(env, params.Model, operation)
if accessErr != nil {
if raiseException {
return nil, accessErr
}
return false, nil
}
return true, nil
case "fields_get":
@@ -404,6 +520,11 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
vals := parseValuesAt(params.Args, 1)
spec, _ := params.KW["specification"].(map[string]interface{})
// Field-level access control
if err := checkSensitiveFields(env, params.Model, vals); err != nil {
return nil, err
}
if len(ids) > 0 && ids[0] > 0 {
// Update existing record(s)
err := rs.Browse(ids...).Write(vals)
@@ -513,6 +634,9 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
case "create":
vals := parseValues(params.Args)
if err := checkSensitiveFields(env, params.Model, vals); err != nil {
return nil, err
}
record, err := rs.Create(vals)
if err != nil {
return nil, &RPCError{Code: -32000, Message: err.Error()}
@@ -522,6 +646,9 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
case "write":
ids := parseIDs(params.Args)
vals := parseValuesAt(params.Args, 1)
if err := checkSensitiveFields(env, params.Model, vals); err != nil {
return nil, err
}
err := rs.Browse(ids...).Write(vals)
if err != nil {
return nil, &RPCError{Code: -32000, Message: err.Error()}
@@ -645,9 +772,33 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
}, nil
case "get_formview_id":
return false, nil
// Return the default form view ID for this model.
// Mirrors: odoo/orm/models.py BaseModel.get_formview_id()
var viewID *int64
err := env.Tx().QueryRow(env.Ctx(),
`SELECT id FROM ir_ui_view
WHERE model = $1 AND type = 'form' AND active = true
ORDER BY priority, id LIMIT 1`,
params.Model).Scan(&viewID)
if err != nil || viewID == nil {
return false, nil
}
return *viewID, nil
case "action_get":
// Try registered method first (e.g. res.users has its own action_get).
// Mirrors: odoo/addons/base/models/res_users.py action_get()
model := orm.Registry.Get(params.Model)
if model != nil && model.Methods != nil {
if method, ok := model.Methods["action_get"]; ok {
ids := parseIDs(params.Args)
result, err := method(rs.Browse(ids...), params.Args[1:]...)
if err != nil {
return nil, &RPCError{Code: -32000, Message: err.Error()}
}
return result, nil
}
}
return false, nil
case "name_create":
@@ -665,10 +816,48 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
return []interface{}{created.ID(), nameStr}, nil
case "read_progress_bar":
return map[string]interface{}{}, nil
return s.handleReadProgressBar(rs, params)
case "activity_format":
return []interface{}{}, nil
ids := parseIDs(params.Args)
if len(ids) == 0 {
return []interface{}{}, nil
}
// Search activities for this model/record
actRS := env.Model("mail.activity")
var allActivities []orm.Values
for _, id := range ids {
domain := orm.And(
orm.Leaf("res_model", "=", params.Model),
orm.Leaf("res_id", "=", id),
orm.Leaf("done", "=", false),
)
found, err := actRS.Search(domain, orm.SearchOpts{Order: "date_deadline"})
if err != nil || found.IsEmpty() {
continue
}
records, err := found.Read([]string{"id", "res_model", "res_id", "activity_type_id", "summary", "note", "date_deadline", "user_id", "state"})
if err != nil {
continue
}
allActivities = append(allActivities, records...)
}
if allActivities == nil {
return []interface{}{}, nil
}
// Format M2O fields
actSpec := map[string]interface{}{
"activity_type_id": map[string]interface{}{},
"user_id": map[string]interface{}{},
}
formatM2OFields(env, "mail.activity", allActivities, actSpec)
formatDateFields("mail.activity", allActivities)
normalizeNullFields("mail.activity", allActivities)
actResult := make([]interface{}, len(allActivities))
for i, a := range allActivities {
actResult[i] = a
}
return actResult, nil
case "action_archive":
ids := parseIDs(params.Args)
@@ -697,6 +886,199 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
}
return created.ID(), nil
case "web_resequence":
// Resequence records by their IDs (drag&drop reordering).
// Mirrors: odoo/addons/web/models/models.py web_resequence()
ids := parseIDs(params.Args)
if len(ids) == 0 {
return []orm.Values{}, nil
}
// Parse field_name (default "sequence")
fieldName := "sequence"
if v, ok := params.KW["field_name"].(string); ok {
fieldName = v
}
// Parse offset (default 0)
offset := 0
if v, ok := params.KW["offset"].(float64); ok {
offset = int(v)
}
// Check if field exists on the model
model := orm.Registry.Get(params.Model)
if model == nil || model.GetField(fieldName) == nil {
return []orm.Values{}, nil
}
// Update sequence for each record in order
for i, id := range ids {
if err := rs.Browse(id).Write(orm.Values{fieldName: offset + i}); err != nil {
return nil, &RPCError{Code: -32000, Message: err.Error()}
}
}
// Return records via web_read
spec, _ := params.KW["specification"].(map[string]interface{})
readParams := CallKWParams{
Model: params.Model,
Method: "web_read",
Args: []interface{}{ids},
KW: map[string]interface{}{"specification": spec},
}
return handleWebRead(env, params.Model, readParams)
case "message_post":
// Post a message on the record's chatter.
// Mirrors: odoo/addons/mail/models/mail_thread.py message_post()
ids := parseIDs(params.Args)
if len(ids) == 0 {
return false, nil
}
body, _ := params.KW["body"].(string)
messageType := "comment"
if v, _ := params.KW["message_type"].(string); v != "" {
messageType = v
}
// Get author from current user's partner_id
var authorID int64
if err := env.Tx().QueryRow(env.Ctx(),
`SELECT partner_id FROM res_users WHERE id = $1`, env.UID(),
).Scan(&authorID); err != nil {
log.Printf("warning: message_post author lookup failed: %v", err)
}
// Create mail.message linked to the current model/record
var msgID int64
err := env.Tx().QueryRow(env.Ctx(),
`INSERT INTO mail_message (model, res_id, body, message_type, author_id, date, create_uid, write_uid, create_date, write_date)
VALUES ($1, $2, $3, $4, $5, NOW(), $6, $6, NOW(), NOW())
RETURNING id`,
params.Model, ids[0], body, messageType, authorID, env.UID(),
).Scan(&msgID)
if err != nil {
return nil, &RPCError{Code: -32000, Message: err.Error()}
}
return msgID, nil
case "_message_get_thread":
// Get messages for a record's chatter.
// Mirrors: odoo/addons/mail/models/mail_thread.py
ids := parseIDs(params.Args)
if len(ids) == 0 {
return []interface{}{}, nil
}
rows, err := env.Tx().Query(env.Ctx(),
`SELECT m.id, m.body, m.message_type, m.date,
m.author_id, COALESCE(p.name, ''),
COALESCE(m.subject, ''), COALESCE(m.email_from, '')
FROM mail_message m
LEFT JOIN res_partner p ON p.id = m.author_id
WHERE m.model = $1 AND m.res_id = $2
ORDER BY m.id DESC`,
params.Model, ids[0],
)
if err != nil {
return nil, &RPCError{Code: -32000, Message: err.Error()}
}
defer rows.Close()
var messages []map[string]interface{}
for rows.Next() {
var id int64
var body, msgType, subject, emailFrom string
var date interface{}
var authorID int64
var authorName string
if scanErr := rows.Scan(&id, &body, &msgType, &date, &authorID, &authorName, &subject, &emailFrom); scanErr != nil {
continue
}
msg := map[string]interface{}{
"id": id,
"body": body,
"message_type": msgType,
"date": date,
"subject": subject,
"email_from": emailFrom,
}
if authorID > 0 {
msg["author_id"] = []interface{}{authorID, authorName}
} else {
msg["author_id"] = false
}
messages = append(messages, msg)
}
if messages == nil {
messages = []map[string]interface{}{}
}
return messages, nil
case "read_followers":
ids := parseIDs(params.Args)
if len(ids) == 0 {
return []interface{}{}, nil
}
// Search followers for this model/record
followerRS := env.Model("mail.followers")
domain := orm.And(
orm.Leaf("res_model", "=", params.Model),
orm.Leaf("res_id", "in", ids),
)
found, err := followerRS.Search(domain, orm.SearchOpts{Limit: 100})
if err != nil || found.IsEmpty() {
return []interface{}{}, nil
}
followerRecords, err := found.Read([]string{"id", "res_model", "res_id", "partner_id"})
if err != nil {
return []interface{}{}, nil
}
followerSpec := map[string]interface{}{"partner_id": map[string]interface{}{}}
formatM2OFields(env, "mail.followers", followerRecords, followerSpec)
normalizeNullFields("mail.followers", followerRecords)
followerResult := make([]interface{}, len(followerRecords))
for i, r := range followerRecords {
followerResult[i] = r
}
return followerResult, nil
case "get_activity_data":
// Return activity summary data for records.
// Mirrors: odoo/addons/mail/models/mail_activity_mixin.py
emptyResult := map[string]interface{}{
"activity_types": []interface{}{},
"activity_res_ids": map[string]interface{}{},
"grouped_activities": map[string]interface{}{},
}
ids := parseIDs(params.Args)
if len(ids) == 0 {
return emptyResult, nil
}
// Get activity types
typeRS := env.Model("mail.activity.type")
types, err := typeRS.Search(nil, orm.SearchOpts{Order: "sequence, id"})
if err != nil || types.IsEmpty() {
return emptyResult, nil
}
typeRecords, _ := types.Read([]string{"id", "name"})
typeList := make([]interface{}, len(typeRecords))
for i, t := range typeRecords {
typeList[i] = t
}
return map[string]interface{}{
"activity_types": typeList,
"activity_res_ids": map[string]interface{}{},
"grouped_activities": map[string]interface{}{},
}, nil
default:
// Try registered business methods on the model.
// Mirrors: odoo/service/model.py call_kw() + odoo/addons/web/controllers/dataset.py call_button()
@@ -732,6 +1114,58 @@ func (s *Server) dispatchORM(env *orm.Environment, params CallKWParams) (interfa
// --- Session / Auth Endpoints ---
// loginAttemptInfo tracks login attempts for rate limiting.
type loginAttemptInfo struct {
Count int
LastTime time.Time
}
var (
loginAttempts = make(map[string]loginAttemptInfo)
loginAttemptsMu sync.Mutex
)
// checkLoginRateLimit returns false if the login is rate-limited (too many attempts).
func (s *Server) checkLoginRateLimit(login string) bool {
loginAttemptsMu.Lock()
defer loginAttemptsMu.Unlock()
now := time.Now()
// Periodic cleanup: evict stale entries (>15 min old) to prevent unbounded growth
if len(loginAttempts) > 100 {
for k, v := range loginAttempts {
if now.Sub(v.LastTime) > 15*time.Minute {
delete(loginAttempts, k)
}
}
}
info := loginAttempts[login]
// Reset after 15 minutes
if now.Sub(info.LastTime) > 15*time.Minute {
info = loginAttemptInfo{}
}
// Max 10 attempts per 15 minutes
if info.Count >= 10 {
return false // Rate limited
}
info.Count++
info.LastTime = now
loginAttempts[login] = info
return true
}
// resetLoginRateLimit clears the rate limit counter on successful login.
func (s *Server) resetLoginRateLimit(login string) {
loginAttemptsMu.Lock()
defer loginAttemptsMu.Unlock()
delete(loginAttempts, login)
}
func (s *Server) handleAuthenticate(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
@@ -754,6 +1188,14 @@ func (s *Server) handleAuthenticate(w http.ResponseWriter, r *http.Request) {
return
}
// Rate limit login attempts
if !s.checkLoginRateLimit(params.Login) {
s.writeJSONRPC(w, req.ID, nil, &RPCError{
Code: 429, Message: "Too many login attempts. Please try again later.",
})
return
}
// Query user by login
var uid int64
var companyID int64
@@ -776,16 +1218,40 @@ func (s *Server) handleAuthenticate(w http.ResponseWriter, r *http.Request) {
return
}
// Check password (support both bcrypt and plaintext for migration)
if !tools.CheckPassword(hashedPw, params.Password) && hashedPw != params.Password {
// Check password (bcrypt only — no plaintext fallback)
if !tools.CheckPassword(hashedPw, params.Password) {
s.writeJSONRPC(w, req.ID, nil, &RPCError{
Code: 100, Message: "Access Denied: invalid login or password",
})
return
}
// Successful login reset rate limiter
s.resetLoginRateLimit(params.Login)
// Query allowed companies for the user
allowedCompanyIDs := []int64{companyID}
rows, err := s.pool.Query(r.Context(),
`SELECT DISTINCT c.id FROM res_company c
WHERE c.active = true
ORDER BY c.id`)
if err == nil {
defer rows.Close()
var ids []int64
for rows.Next() {
var cid int64
if rows.Scan(&cid) == nil {
ids = append(ids, cid)
}
}
if len(ids) > 0 {
allowedCompanyIDs = ids
}
}
// Create session
sess := s.sessions.New(uid, companyID, params.Login)
sess.AllowedCompanyIDs = allowedCompanyIDs
// Set session cookie
http.SetCookie(w, &http.Cookie{
@@ -793,6 +1259,7 @@ func (s *Server) handleAuthenticate(w http.ResponseWriter, r *http.Request) {
Value: sess.ID,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
@@ -857,6 +1324,7 @@ func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
Path: "/",
MaxAge: -1,
HttpOnly: true,
Secure: true,
})
http.Redirect(w, r, "/web/login", http.StatusFound)
}

View File

@@ -13,12 +13,14 @@ import (
// Session represents an authenticated user session.
type Session struct {
ID string
UID int64
CompanyID int64
Login string
CreatedAt time.Time
LastActivity time.Time
ID string
UID int64
CompanyID int64
AllowedCompanyIDs []int64
Login string
CSRFToken string
CreatedAt time.Time
LastActivity time.Time
}
// SessionStore is a session store with an in-memory cache backed by PostgreSQL.
@@ -47,10 +49,15 @@ func InitSessionTable(ctx context.Context, pool *pgxpool.Pool) error {
uid INT8 NOT NULL,
company_id INT8 NOT NULL,
login VARCHAR(255),
csrf_token VARCHAR(64) DEFAULT '',
created_at TIMESTAMP DEFAULT NOW(),
last_seen TIMESTAMP DEFAULT NOW()
)
`)
if err == nil {
// Add csrf_token column if table already exists without it
pool.Exec(ctx, `ALTER TABLE sessions ADD COLUMN IF NOT EXISTS csrf_token VARCHAR(64) DEFAULT ''`)
}
if err != nil {
return err
}
@@ -67,6 +74,7 @@ func (s *SessionStore) New(uid, companyID int64, login string) *Session {
UID: uid,
CompanyID: companyID,
Login: login,
CSRFToken: generateToken(),
CreatedAt: now,
LastActivity: now,
}
@@ -81,10 +89,10 @@ func (s *SessionStore) New(uid, companyID int64, login string) *Session {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err := s.pool.Exec(ctx,
`INSERT INTO sessions (id, uid, company_id, login, created_at, last_seen)
VALUES ($1, $2, $3, $4, $5, $6)
`INSERT INTO sessions (id, uid, company_id, login, csrf_token, created_at, last_seen)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (id) DO NOTHING`,
token, uid, companyID, login, now, now)
token, uid, companyID, login, sess.CSRFToken, now, now)
if err != nil {
log.Printf("session: failed to persist session to DB: %v", err)
}
@@ -106,20 +114,23 @@ func (s *SessionStore) Get(id string) *Session {
s.Delete(id)
return nil
}
// Update last activity
now := time.Now()
needsDBUpdate := time.Since(sess.LastActivity) > 30*time.Second
// Update last activity in memory
s.mu.Lock()
sess.LastActivity = now
s.mu.Unlock()
// Update last_seen in DB asynchronously
if s.pool != nil {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.pool.Exec(ctx,
`UPDATE sessions SET last_seen = $1 WHERE id = $2`, now, id)
}()
// Throttle DB writes: only persist every 30s to avoid per-request overhead
if needsDBUpdate && s.pool != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := s.pool.Exec(ctx,
`UPDATE sessions SET last_seen = $1 WHERE id = $2`, now, id); err != nil {
log.Printf("session: failed to update last_seen in DB: %v", err)
}
}
return sess
@@ -134,14 +145,20 @@ func (s *SessionStore) Get(id string) *Session {
defer cancel()
sess = &Session{}
var csrfToken string
err := s.pool.QueryRow(ctx,
`SELECT id, uid, company_id, login, created_at, last_seen
`SELECT id, uid, company_id, login, COALESCE(csrf_token, ''), created_at, last_seen
FROM sessions WHERE id = $1`, id).Scan(
&sess.ID, &sess.UID, &sess.CompanyID, &sess.Login,
&sess.ID, &sess.UID, &sess.CompanyID, &sess.Login, &csrfToken,
&sess.CreatedAt, &sess.LastActivity)
if err != nil {
return nil
}
if csrfToken != "" {
sess.CSRFToken = csrfToken
} else {
sess.CSRFToken = generateToken()
}
// Check TTL
if time.Since(sess.LastActivity) > s.ttl {
@@ -149,18 +166,18 @@ func (s *SessionStore) Get(id string) *Session {
return nil
}
// Update last activity
// Update last activity and add to memory cache
now := time.Now()
sess.LastActivity = now
// Add to memory cache
s.mu.Lock()
sess.LastActivity = now
s.sessions[id] = sess
s.mu.Unlock()
// Update last_seen in DB
s.pool.Exec(ctx,
`UPDATE sessions SET last_seen = $1 WHERE id = $2`, now, id)
if _, err := s.pool.Exec(ctx,
`UPDATE sessions SET last_seen = $1 WHERE id = $2`, now, id); err != nil {
log.Printf("session: failed to update last_seen in DB: %v", err)
}
return sess
}

View File

@@ -6,14 +6,17 @@ import (
"fmt"
"log"
"net/http"
"os"
"regexp"
"strings"
"sync/atomic"
"time"
"odoo-go/pkg/service"
"odoo-go/pkg/tools"
)
var dbnamePattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]+$`)
// isSetupNeeded checks if the current database has been initialized.
@@ -55,6 +58,16 @@ func (s *Server) handleDatabaseCreate(w http.ResponseWriter, r *http.Request) {
return
}
// Validate master password (default: "admin", configurable via ODOO_MASTER_PASSWORD env)
masterPw := os.Getenv("ODOO_MASTER_PASSWORD")
if masterPw == "" {
masterPw = "admin"
}
if params.MasterPwd != masterPw {
writeJSON(w, map[string]string{"error": "Invalid master password"})
return
}
// Validate
if params.Login == "" || params.Password == "" {
writeJSON(w, map[string]string{"error": "Email and password are required"})
@@ -111,7 +124,10 @@ func (s *Server) handleDatabaseCreate(w http.ResponseWriter, r *http.Request) {
domain := parts[1]
domainParts := strings.Split(domain, ".")
if len(domainParts) > 0 {
companyName = strings.Title(domainParts[0])
name := domainParts[0]
if len(name) > 0 {
companyName = strings.ToUpper(name[:1]) + name[1:]
}
}
}
}
@@ -175,6 +191,195 @@ func writeJSON(w http.ResponseWriter, v interface{}) {
json.NewEncoder(w).Encode(v)
}
// postSetupDone caches the result of isPostSetupNeeded to avoid a DB query on every request.
var postSetupDone atomic.Bool
// isPostSetupNeeded checks if the company still has default values (needs configuration).
func (s *Server) isPostSetupNeeded() bool {
if postSetupDone.Load() {
return false
}
var name string
err := s.pool.QueryRow(context.Background(),
`SELECT COALESCE(name, '') FROM res_company WHERE id = 1`).Scan(&name)
if err != nil {
return false
}
needed := name == "" || name == "My Company" || strings.HasPrefix(name, "My ")
if !needed {
postSetupDone.Store(true)
}
return needed
}
// handleSetupWizard serves the post-setup configuration wizard.
// Shown after first login when the company has not been configured yet.
// Mirrors: odoo/addons/base_setup/views/res_config_settings_views.xml
func (s *Server) handleSetupWizard(w http.ResponseWriter, r *http.Request) {
sess := GetSession(r)
if sess == nil {
http.Redirect(w, r, "/web/login", http.StatusFound)
return
}
// Load current company data
var companyName, street, city, zip, phone, email, website, vat string
var countryID int64
s.pool.QueryRow(context.Background(),
`SELECT COALESCE(name,''), COALESCE(street,''), COALESCE(city,''), COALESCE(zip,''),
COALESCE(phone,''), COALESCE(email,''), COALESCE(website,''), COALESCE(vat,''),
COALESCE(country_id, 0)
FROM res_company WHERE id = $1`, sess.CompanyID,
).Scan(&companyName, &street, &city, &zip, &phone, &email, &website, &vat, &countryID)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
esc := htmlEscape
fmt.Fprintf(w, setupWizardHTML,
esc(companyName), esc(street), esc(city), esc(zip), esc(phone), esc(email), esc(website), esc(vat))
}
// handleSetupWizardSave saves the post-setup wizard data.
func (s *Server) handleSetupWizardSave(w http.ResponseWriter, r *http.Request) {
sess := GetSession(r)
if sess == nil {
writeJSON(w, map[string]string{"error": "Not authenticated"})
return
}
var params struct {
CompanyName string `json:"company_name"`
Street string `json:"street"`
City string `json:"city"`
Zip string `json:"zip"`
Phone string `json:"phone"`
Email string `json:"email"`
Website string `json:"website"`
Vat string `json:"vat"`
}
if err := json.NewDecoder(r.Body).Decode(&params); err != nil {
writeJSON(w, map[string]string{"error": "Invalid request"})
return
}
if params.CompanyName == "" {
writeJSON(w, map[string]string{"error": "Company name is required"})
return
}
_, err := s.pool.Exec(context.Background(),
`UPDATE res_company SET name=$1, street=$2, city=$3, zip=$4, phone=$5, email=$6, website=$7, vat=$8
WHERE id = $9`,
params.CompanyName, params.Street, params.City, params.Zip,
params.Phone, params.Email, params.Website, params.Vat, sess.CompanyID)
if err != nil {
writeJSON(w, map[string]string{"error": fmt.Sprintf("Save error: %v", err)})
return
}
// Also update the partner linked to the company
s.pool.Exec(context.Background(),
`UPDATE res_partner SET name=$1, street=$2, city=$3, zip=$4, phone=$5, email=$6, website=$7, vat=$8
WHERE id = (SELECT partner_id FROM res_company WHERE id = $9)`,
params.CompanyName, params.Street, params.City, params.Zip,
params.Phone, params.Email, params.Website, params.Vat, sess.CompanyID)
postSetupDone.Store(true) // Mark setup as done so we don't redirect again
writeJSON(w, map[string]interface{}{"status": "ok", "redirect": "/odoo"})
}
var setupWizardHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1"/>
<title>Setup — Configure Your Company</title>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
background: #f0eeee; display: flex; align-items: center; justify-content: center; min-height: 100vh; }
.wizard { background: white; padding: 40px; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);
width: 100%%; max-width: 560px; }
.wizard h1 { color: #71639e; margin-bottom: 6px; font-size: 24px; }
.wizard .subtitle { color: #666; margin-bottom: 24px; font-size: 14px; }
.wizard label { display: block; margin-bottom: 4px; font-weight: 500; color: #555; font-size: 13px; }
.wizard input { width: 100%%; padding: 9px 12px; border: 1px solid #ddd; border-radius: 4px;
font-size: 14px; margin-bottom: 14px; }
.wizard input:focus { outline: none; border-color: #71639e; box-shadow: 0 0 0 2px rgba(113,99,158,0.2); }
.wizard button { width: 100%%; padding: 14px; background: #71639e; color: white; border: none;
border-radius: 4px; font-size: 16px; cursor: pointer; margin-top: 10px; }
.wizard button:hover { background: #5f5387; }
.wizard .skip { text-align: center; margin-top: 12px; }
.wizard .skip a { color: #999; text-decoration: none; font-size: 13px; }
.wizard .skip a:hover { color: #666; }
.row { display: flex; gap: 12px; }
.row > div { flex: 1; }
.error { color: #dc3545; margin-bottom: 12px; display: none; text-align: center; font-size: 14px; }
</style>
</head>
<body>
<div class="wizard">
<h1>Configure Your Company</h1>
<p class="subtitle">Set up your company information</p>
<div id="error" class="error"></div>
<form id="wizardForm">
<label>Company Name *</label>
<input type="text" id="company_name" value="%s" required/>
<label>Street</label>
<input type="text" id="street" value="%s"/>
<div class="row">
<div><label>City</label><input type="text" id="city" value="%s"/></div>
<div><label>ZIP</label><input type="text" id="zip" value="%s"/></div>
</div>
<div class="row">
<div><label>Phone</label><input type="tel" id="phone" value="%s"/></div>
<div><label>Email</label><input type="email" id="email" value="%s"/></div>
</div>
<label>Website</label>
<input type="url" id="website" value="%s" placeholder="https://"/>
<label>Tax ID / VAT</label>
<input type="text" id="vat" value="%s"/>
<button type="submit">Save & Continue</button>
</form>
<div class="skip"><a href="/odoo">Skip for now</a></div>
</div>
<script>
document.getElementById('wizardForm').addEventListener('submit', function(e) {
e.preventDefault();
fetch('/web/setup/wizard/save', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({
company_name: document.getElementById('company_name').value,
street: document.getElementById('street').value,
city: document.getElementById('city').value,
zip: document.getElementById('zip').value,
phone: document.getElementById('phone').value,
email: document.getElementById('email').value,
website: document.getElementById('website').value,
vat: document.getElementById('vat').value
})
})
.then(function(r) { return r.json(); })
.then(function(result) {
if (result.error) {
var el = document.getElementById('error');
el.textContent = result.error;
el.style.display = 'block';
} else {
window.location.href = result.redirect || '/odoo';
}
});
});
</script>
</body>
</html>`
// --- Database Manager HTML ---
// Mirrors: odoo/addons/web/static/src/public/database_manager.create_form.qweb.html
var databaseManagerHTML = `<!DOCTYPE html>

View File

@@ -43,8 +43,9 @@ func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) {
addonName := parts[0]
filePath := parts[2]
// Security: prevent directory traversal
if strings.Contains(filePath, "..") {
// Security: prevent directory traversal in both addonName and filePath
if strings.Contains(filePath, "..") || strings.Contains(addonName, "..") ||
strings.Contains(addonName, "/") || strings.Contains(addonName, "\\") {
http.NotFound(w, r)
return
}

View File

@@ -88,7 +88,7 @@ func TestExtractImports(t *testing.T) {
content := `import { Foo, Bar } from "@web/core/foo";
import { Baz as Qux } from "@web/core/baz";
const x = 1;`
deps, requires, clean := extractImports(content)
deps, requires, clean := extractImports("test.module", content)
if len(deps) != 2 {
t.Fatalf("expected 2 deps, got %d: %v", len(deps), deps)
@@ -120,7 +120,7 @@ const x = 1;`
t.Run("default import", func(t *testing.T) {
content := `import Foo from "@web/core/foo";`
deps, requires, _ := extractImports(content)
deps, requires, _ := extractImports("test.module", content)
if len(deps) != 1 || deps[0] != "@web/core/foo" {
t.Errorf("deps = %v, want [@web/core/foo]", deps)
@@ -132,7 +132,7 @@ const x = 1;`
t.Run("namespace import", func(t *testing.T) {
content := `import * as utils from "@web/core/utils";`
deps, requires, _ := extractImports(content)
deps, requires, _ := extractImports("test.module", content)
if len(deps) != 1 || deps[0] != "@web/core/utils" {
t.Errorf("deps = %v, want [@web/core/utils]", deps)
@@ -144,7 +144,7 @@ const x = 1;`
t.Run("side-effect import", func(t *testing.T) {
content := `import "@web/core/setup";`
deps, requires, _ := extractImports(content)
deps, requires, _ := extractImports("test.module", content)
if len(deps) != 1 || deps[0] != "@web/core/setup" {
t.Errorf("deps = %v, want [@web/core/setup]", deps)
@@ -157,7 +157,7 @@ const x = 1;`
t.Run("dedup deps", func(t *testing.T) {
content := `import { Foo } from "@web/core/foo";
import { Bar } from "@web/core/foo";`
deps, _, _ := extractImports(content)
deps, _, _ := extractImports("test.module", content)
if len(deps) != 1 {
t.Errorf("expected deduped deps, got %v", deps)
@@ -167,7 +167,7 @@ import { Bar } from "@web/core/foo";`
func TestTransformExports(t *testing.T) {
t.Run("export class", func(t *testing.T) {
got := transformExports("export class Foo extends Bar {")
got, _ := transformExports("export class Foo extends Bar {")
want := "const Foo = __exports.Foo = class Foo extends Bar {"
if got != want {
t.Errorf("got %q, want %q", got, want)
@@ -175,15 +175,18 @@ func TestTransformExports(t *testing.T) {
})
t.Run("export function", func(t *testing.T) {
got := transformExports("export function doSomething(a, b) {")
want := `__exports.doSomething = function doSomething(a, b) {`
got, deferred := transformExports("export function doSomething(a, b) {")
want := `function doSomething(a, b) {`
if got != want {
t.Errorf("got %q, want %q", got, want)
}
if len(deferred) != 1 || deferred[0] != "doSomething" {
t.Errorf("deferred = %v, want [doSomething]", deferred)
}
})
t.Run("export const", func(t *testing.T) {
got := transformExports("export const MAX_SIZE = 100;")
got, _ := transformExports("export const MAX_SIZE = 100;")
want := "const MAX_SIZE = __exports.MAX_SIZE = 100;"
if got != want {
t.Errorf("got %q, want %q", got, want)
@@ -191,7 +194,7 @@ func TestTransformExports(t *testing.T) {
})
t.Run("export let", func(t *testing.T) {
got := transformExports("export let counter = 0;")
got, _ := transformExports("export let counter = 0;")
want := "let counter = __exports.counter = 0;"
if got != want {
t.Errorf("got %q, want %q", got, want)
@@ -199,7 +202,7 @@ func TestTransformExports(t *testing.T) {
})
t.Run("export default", func(t *testing.T) {
got := transformExports("export default Foo;")
got, _ := transformExports("export default Foo;")
want := `__exports[Symbol.for("default")] = Foo;`
if got != want {
t.Errorf("got %q, want %q", got, want)
@@ -207,7 +210,7 @@ func TestTransformExports(t *testing.T) {
})
t.Run("export named", func(t *testing.T) {
got := transformExports("export { Foo, Bar };")
got, _ := transformExports("export { Foo, Bar };")
if !strings.Contains(got, "__exports.Foo = Foo;") {
t.Errorf("missing Foo export in: %s", got)
}
@@ -217,7 +220,7 @@ func TestTransformExports(t *testing.T) {
})
t.Run("export named with alias", func(t *testing.T) {
got := transformExports("export { Foo as default };")
got, _ := transformExports("export { Foo as default };")
if !strings.Contains(got, "__exports.default = Foo;") {
t.Errorf("missing aliased export in: %s", got)
}

View File

@@ -20,12 +20,27 @@ func (s *Server) handleUpload(w http.ResponseWriter, r *http.Request) {
return
}
// Parse multipart form (max 128MB)
if err := r.ParseMultipartForm(128 << 20); err != nil {
// Limit upload size to 50MB
r.Body = http.MaxBytesReader(w, r.Body, 50<<20)
// Parse multipart form (max 50MB)
if err := r.ParseMultipartForm(50 << 20); err != nil {
http.Error(w, "File too large", http.StatusRequestEntityTooLarge)
return
}
// CSRF validation for multipart form uploads.
// Mirrors: odoo/http.py validate_csrf()
sess := GetSession(r)
if sess != nil {
csrfToken := r.FormValue("csrf_token")
if csrfToken != sess.CSRFToken {
log.Printf("upload: CSRF token mismatch for uid=%d", sess.UID)
http.Error(w, "CSRF validation failed", http.StatusForbidden)
return
}
}
file, header, err := r.FormFile("ufile")
if err != nil {
http.Error(w, "No file uploaded", http.StatusBadRequest)

View File

@@ -195,6 +195,12 @@ func generateDefaultView(modelName, viewType string) string {
return generateDefaultPivotView(m)
case "graph":
return generateDefaultGraphView(m)
case "calendar":
return generateDefaultCalendarView(m)
case "activity":
return generateDefaultActivityView(m)
case "dashboard":
return generateDefaultDashboardView(m)
default:
return fmt.Sprintf("<%s><field name=\"id\"/></%s>", viewType, viewType)
}
@@ -530,6 +536,161 @@ func generateDefaultGraphView(m *orm.Model) string {
return fmt.Sprintf("<graph>\n %s\n</graph>", strings.Join(fields, "\n "))
}
// generateDefaultCalendarView creates a calendar view with auto-detected date fields.
// The OWL CalendarArchParser requires date_start; date_stop and color are optional.
// Mirrors: odoo/addons/web/static/src/views/calendar/calendar_arch_parser.js
func generateDefaultCalendarView(m *orm.Model) string {
// Auto-detect date_start field (priority order)
dateStart := ""
for _, candidate := range []string{"start", "date_start", "date_from", "date_order", "date_begin", "date"} {
if f := m.GetField(candidate); f != nil && (f.Type == orm.TypeDatetime || f.Type == orm.TypeDate) {
dateStart = candidate
break
}
}
if dateStart == "" {
// Fallback: find any datetime/date field
for _, name := range sortedFieldNames(m) {
f := m.GetField(name)
if f != nil && (f.Type == orm.TypeDatetime || f.Type == orm.TypeDate) && f.Name != "create_date" && f.Name != "write_date" {
dateStart = name
break
}
}
}
if dateStart == "" {
// No date field found — return minimal arch that won't crash
return `<calendar date_start="create_date"><field name="display_name"/></calendar>`
}
// Auto-detect date_stop field
dateStop := ""
for _, candidate := range []string{"stop", "date_stop", "date_to", "date_end"} {
if f := m.GetField(candidate); f != nil && (f.Type == orm.TypeDatetime || f.Type == orm.TypeDate) {
dateStop = candidate
break
}
}
// Auto-detect color field (M2O fields make good color discriminators)
colorField := ""
for _, candidate := range []string{"color", "user_id", "partner_id", "stage_id"} {
if f := m.GetField(candidate); f != nil {
colorField = candidate
break
}
}
// Auto-detect all_day field
allDay := ""
for _, candidate := range []string{"allday", "all_day"} {
if f := m.GetField(candidate); f != nil && f.Type == orm.TypeBoolean {
allDay = candidate
break
}
}
// Build attributes
attrs := fmt.Sprintf(`date_start="%s"`, dateStart)
if dateStop != "" {
attrs += fmt.Sprintf(` date_stop="%s"`, dateStop)
}
if colorField != "" {
attrs += fmt.Sprintf(` color="%s"`, colorField)
}
if allDay != "" {
attrs += fmt.Sprintf(` all_day="%s"`, allDay)
}
// Pick display fields for the calendar card
var fields []string
nameField := "display_name"
if f := m.GetField("name"); f != nil {
nameField = "name"
}
fields = append(fields, fmt.Sprintf(` <field name="%s"/>`, nameField))
if f := m.GetField("partner_id"); f != nil {
fields = append(fields, ` <field name="partner_id" avatar_field="avatar_128"/>`)
}
if f := m.GetField("user_id"); f != nil && colorField != "user_id" {
fields = append(fields, ` <field name="user_id"/>`)
}
return fmt.Sprintf("<calendar %s mode=\"month\">\n%s\n</calendar>",
attrs, strings.Join(fields, "\n"))
}
// generateDefaultActivityView creates a minimal activity view.
// Mirrors: odoo/addons/mail/static/src/views/web_activity/activity_arch_parser.js
func generateDefaultActivityView(m *orm.Model) string {
nameField := "display_name"
if f := m.GetField("name"); f != nil {
nameField = "name"
}
return fmt.Sprintf(`<activity string="Activities">
<templates>
<div t-name="activity-box">
<field name="%s"/>
</div>
</templates>
</activity>`, nameField)
}
// generateDefaultDashboardView creates a dashboard view with aggregate widgets.
// Mirrors: odoo/addons/board/static/src/board_view.js
func generateDefaultDashboardView(m *orm.Model) string {
var widgets []string
// Add aggregate widgets for numeric fields
for _, name := range sortedFieldNames(m) {
f := m.GetField(name)
if f == nil {
continue
}
if (f.Type == orm.TypeFloat || f.Type == orm.TypeInteger || f.Type == orm.TypeMonetary) &&
f.IsStored() && f.Name != "id" && f.Name != "sequence" &&
f.Name != "create_uid" && f.Name != "write_uid" && f.Name != "company_id" {
widgets = append(widgets, fmt.Sprintf(
` <aggregate name="%s" field="%s" string="%s"/>`,
f.Name, f.Name, f.String))
if len(widgets) >= 6 {
break
}
}
}
// Add a graph for the first groupable dimension
var graphField string
for _, name := range sortedFieldNames(m) {
f := m.GetField(name)
if f != nil && f.IsStored() && (f.Type == orm.TypeMany2one || f.Type == orm.TypeSelection) {
graphField = name
break
}
}
var buf strings.Builder
buf.WriteString("<dashboard>\n")
if len(widgets) > 0 {
buf.WriteString(" <group>\n")
for _, w := range widgets {
buf.WriteString(w + "\n")
}
buf.WriteString(" </group>\n")
}
if graphField != "" {
buf.WriteString(fmt.Sprintf(` <view type="graph">
<graph type="bar">
<field name="%s"/>
</graph>
</view>
`, graphField))
}
buf.WriteString("</dashboard>")
return buf.String()
}
// sortedFieldNames returns field names in alphabetical order for deterministic output.
func sortedFieldNames(m *orm.Model) []string {
fields := m.Fields()

View File

@@ -2,6 +2,7 @@ package server
import (
"fmt"
"strings"
"time"
"odoo-go/pkg/orm"
@@ -451,6 +452,110 @@ func (s *Server) handleReadGroup(rs *orm.Recordset, params CallKWParams) (interf
}
if params.Method == "web_read_group" {
// --- __fold support ---
// If the first groupby is a Many2one whose comodel has a "fold" field,
// add __fold to each group. Mirrors: odoo/addons/web/models/models.py
if len(groupby) > 0 {
fieldName := strings.SplitN(groupby[0], ":", 2)[0]
m := rs.ModelDef()
if m != nil {
f := m.GetField(fieldName)
if f != nil && f.Type == orm.TypeMany2one && f.Comodel != "" {
comodel := orm.Registry.Get(f.Comodel)
if comodel != nil && comodel.GetField("fold") != nil {
addFoldInfo(rs.Env(), f.Comodel, groupby[0], groups)
}
}
}
}
// --- __records for auto_unfold ---
autoUnfold := false
if v, ok := params.KW["auto_unfold"].(bool); ok {
autoUnfold = v
}
if autoUnfold {
unfoldReadSpec, _ := params.KW["unfold_read_specification"].(map[string]interface{})
unfoldLimit := defaultWebSearchLimit
if v, ok := params.KW["unfold_read_default_limit"].(float64); ok {
unfoldLimit = int(v)
}
// Parse original domain for combining with group domain
origDomain := parseDomain(params.Args)
if origDomain == nil {
if dr, ok := params.KW["domain"].([]interface{}); ok && len(dr) > 0 {
origDomain = parseDomain([]interface{}{dr})
}
}
modelName := rs.ModelDef().Name()
maxUnfolded := 10
unfolded := 0
for _, g := range groups {
if unfolded >= maxUnfolded {
break
}
gm := g.(map[string]interface{})
fold, _ := gm["__fold"].(bool)
count, _ := gm["__count"].(int64)
// Skip folded, empty, and groups with false/nil M2O value
// Mirrors: odoo/addons/web/models/models.py _open_groups() fold checks
if fold || count == 0 {
continue
}
// For M2O groupby: skip groups where the value is false (unset M2O)
if len(groupby) > 0 {
gbVal := gm[groupby[0]]
if gbVal == nil || gbVal == false {
continue
}
}
// Build combined domain: original + group extra domain
var combinedDomain orm.Domain
if origDomain != nil {
combinedDomain = append(combinedDomain, origDomain...)
}
if extraDom, ok := gm["__extra_domain"].([]interface{}); ok && len(extraDom) > 0 {
groupDomain := parseDomain([]interface{}{extraDom})
combinedDomain = append(combinedDomain, groupDomain...)
}
found, err := rs.Env().Model(modelName).Search(combinedDomain, orm.SearchOpts{Limit: unfoldLimit})
if err != nil || found.IsEmpty() {
gm["__records"] = []orm.Values{}
unfolded++
continue
}
fields := specToFields(unfoldReadSpec)
if len(fields) == 0 {
fields = []string{"id"}
}
hasID := false
for _, f := range fields {
if f == "id" {
hasID = true
break
}
}
if !hasID {
fields = append([]string{"id"}, fields...)
}
records, err := found.Read(fields)
if err != nil {
gm["__records"] = []orm.Values{}
unfolded++
continue
}
formatRecordsForWeb(rs.Env(), modelName, records, unfoldReadSpec)
gm["__records"] = records
unfolded++
}
}
// web_read_group: also get total group count (without limit/offset)
totalLen := len(results)
if opts.Limit > 0 || opts.Offset > 0 {
@@ -470,6 +575,203 @@ func (s *Server) handleReadGroup(rs *orm.Recordset, params CallKWParams) (interf
return groups, nil
}
// handleReadProgressBar returns per-group counts for a progress bar field.
// Mirrors: odoo/orm/models.py BaseModel._read_progress_bar()
//
// Called by the kanban view to render colored progress bars per column.
// Input (via KW):
//
// domain: search filter
// group_by: field to group columns by (e.g. "stage_id")
// progress_bar: {field: "kanban_state", colors: {"done": "success", ...}}
//
// Output:
//
// {groupByValue: {pbValue: count, ...}, ...}
//
// Where groupByValue is the raw DB value (integer ID for M2O, string for
// selection, "True"/"False" for boolean).
func (s *Server) handleReadProgressBar(rs *orm.Recordset, params CallKWParams) (interface{}, *RPCError) {
// Parse domain from KW
domain := parseDomain(params.Args)
if domain == nil {
if dr, ok := params.KW["domain"].([]interface{}); ok && len(dr) > 0 {
domain = parseDomain([]interface{}{dr})
}
}
// Parse group_by (single string)
groupBy := ""
if v, ok := params.KW["group_by"].(string); ok {
groupBy = v
}
// Parse progress_bar map
progressBar, _ := params.KW["progress_bar"].(map[string]interface{})
pbField, _ := progressBar["field"].(string)
if groupBy == "" || pbField == "" {
return map[string]interface{}{}, nil
}
// Use ReadGroup with two groupby levels: [groupBy, pbField]
results, err := rs.ReadGroup(domain, []string{groupBy, pbField}, []string{"__count"})
if err != nil {
return map[string]interface{}{}, nil
}
// Determine field types for key formatting
m := rs.ModelDef()
gbField := m.GetField(groupBy)
pbFieldDef := m.GetField(pbField)
// Build nested map: {groupByValue: {pbValue: count}}
data := make(map[string]interface{})
// Collect all known progress bar values (from colors) so we initialize zeros
pbColors, _ := progressBar["colors"].(map[string]interface{})
for _, r := range results {
// Format the group-by key
gbVal := r.GroupValues[groupBy]
gbKey := formatProgressBarKey(gbVal, gbField)
// Format the progress bar value
pbVal := r.GroupValues[pbField]
pbKey := formatProgressBarValue(pbVal, pbFieldDef)
// Initialize group entry with zero counts if first time
if _, exists := data[gbKey]; !exists {
entry := make(map[string]interface{})
for colorKey := range pbColors {
entry[colorKey] = 0
}
data[gbKey] = entry
}
// Add count
entry := data[gbKey].(map[string]interface{})
existing, _ := entry[pbKey].(int)
entry[pbKey] = existing + int(r.Count)
}
return data, nil
}
// formatProgressBarKey formats a group-by value as the string key expected
// by the frontend progress bar.
// - M2O: integer ID (as string)
// - Boolean: "True" / "False"
// - nil/false: "False"
// - Other: value as string
func formatProgressBarKey(val interface{}, f *orm.Field) string {
if val == nil || val == false {
return "False"
}
// M2O: ReadGroup resolves to [id, name] pair — use the id
if f != nil && f.Type == orm.TypeMany2one {
switch v := val.(type) {
case []interface{}:
if len(v) > 0 {
return fmt.Sprintf("%v", v[0])
}
return "False"
case int64:
return fmt.Sprintf("%d", v)
case float64:
return fmt.Sprintf("%d", int64(v))
case int:
return fmt.Sprintf("%d", v)
}
}
// Boolean
if f != nil && f.Type == orm.TypeBoolean {
switch v := val.(type) {
case bool:
if v {
return "True"
}
return "False"
}
}
return fmt.Sprintf("%v", val)
}
// formatProgressBarValue formats a progress bar field value as a string key.
// Selection fields use the raw value (e.g. "done", "blocked").
// Boolean fields use "True"/"False".
func formatProgressBarValue(val interface{}, f *orm.Field) string {
if val == nil || val == false {
return "False"
}
if f != nil && f.Type == orm.TypeBoolean {
switch v := val.(type) {
case bool:
if v {
return "True"
}
return "False"
}
}
return fmt.Sprintf("%v", val)
}
// addFoldInfo reads the "fold" boolean from the comodel records referenced
// by each group and sets __fold on the group maps accordingly.
func addFoldInfo(env *orm.Environment, comodel string, groupbySpec string, groups []interface{}) {
// Collect IDs from group values (M2O pairs like [id, name])
var ids []int64
for _, g := range groups {
gm := g.(map[string]interface{})
val := gm[groupbySpec]
if pair, ok := val.([]interface{}); ok && len(pair) >= 1 {
if id, ok := orm.ToRecordID(pair[0]); ok && id > 0 {
ids = append(ids, id)
}
}
}
if len(ids) == 0 {
// All groups have false/empty value — fold them by default
for _, g := range groups {
gm := g.(map[string]interface{})
gm["__fold"] = false
}
return
}
// Read fold values from comodel
rs := env.Model(comodel).Browse(ids...)
records, err := rs.Read([]string{"id", "fold"})
if err != nil {
return
}
// Build fold map
foldMap := make(map[int64]bool)
for _, rec := range records {
id, _ := orm.ToRecordID(rec["id"])
fold, _ := rec["fold"].(bool)
foldMap[id] = fold
}
// Apply to groups
for _, g := range groups {
gm := g.(map[string]interface{})
val := gm[groupbySpec]
if pair, ok := val.([]interface{}); ok && len(pair) >= 1 {
if id, ok := orm.ToRecordID(pair[0]); ok {
gm["__fold"] = foldMap[id]
}
} else {
// false/empty group value
gm["__fold"] = false
}
}
}
// formatDateFields converts date/datetime values to Odoo's expected string format.
func formatDateFields(model string, records []orm.Values) {
m := orm.Registry.Get(model)

View File

@@ -2,6 +2,7 @@ package server
import (
"bufio"
"context"
"embed"
"encoding/json"
"fmt"
@@ -73,6 +74,14 @@ func (s *Server) handleWebClient(w http.ResponseWriter, r *http.Request) {
return
}
// Check if post-setup wizard is needed (first login, company not configured)
if s.isPostSetupNeeded() {
if sess := GetSession(r); sess != nil {
http.Redirect(w, r, "/web/setup/wizard", http.StatusFound)
return
}
}
// Check authentication
sess := GetSession(r)
if sess == nil {
@@ -141,7 +150,7 @@ func (s *Server) handleWebClient(w http.ResponseWriter, r *http.Request) {
%s
<script>
var odoo = {
csrf_token: "dummy",
csrf_token: "%s",
debug: "assets",
__session_info__: %s,
reloadMenus: function() {
@@ -178,12 +187,18 @@ func (s *Server) handleWebClient(w http.ResponseWriter, r *http.Request) {
%s</head>
<body class="o_web_client">
</body>
</html>`, linkTags.String(), sessionInfoJSON, scriptTags.String())
</html>`, linkTags.String(), sess.CSRFToken, sessionInfoJSON, scriptTags.String())
}
// buildSessionInfo constructs the session_info JSON object expected by the webclient.
// Mirrors: odoo/addons/web/models/ir_http.py session_info()
func (s *Server) buildSessionInfo(sess *Session) map[string]interface{} {
// Build allowed_company_ids from session (populated at login)
allowedIDs := sess.AllowedCompanyIDs
if len(allowedIDs) == 0 {
allowedIDs = []int64{sess.CompanyID}
}
return map[string]interface{}{
"session_id": sess.ID,
"uid": sess.UID,
@@ -194,7 +209,7 @@ func (s *Server) buildSessionInfo(sess *Session) map[string]interface{} {
"user_context": map[string]interface{}{
"lang": "en_US",
"tz": "UTC",
"allowed_company_ids": []int64{sess.CompanyID},
"allowed_company_ids": allowedIDs,
},
"db": s.config.DBName,
"registry_hash": fmt.Sprintf("odoo-go-%d", time.Now().Unix()),
@@ -213,7 +228,7 @@ func (s *Server) buildSessionInfo(sess *Session) map[string]interface{} {
"current_menu": 1,
"support_url": "",
"notification_type": "email",
"display_switch_company_menu": false,
"display_switch_company_menu": len(allowedIDs) > 1,
"test_mode": false,
"show_effect": true,
"currencies": map[string]interface{}{
@@ -226,20 +241,7 @@ func (s *Server) buildSessionInfo(sess *Session) map[string]interface{} {
"lang": "en_US",
"debug": "assets",
},
"user_companies": map[string]interface{}{
"current_company": sess.CompanyID,
"allowed_companies": map[string]interface{}{
fmt.Sprintf("%d", sess.CompanyID): map[string]interface{}{
"id": sess.CompanyID,
"name": "My Company",
"sequence": 10,
"child_ids": []int64{},
"parent_id": false,
"currency_id": 1,
},
},
"disallowed_ancestor_companies": map[string]interface{}{},
},
"user_companies": s.buildUserCompanies(sess.CompanyID, allowedIDs),
"user_settings": map[string]interface{}{
"id": 1,
"user_id": map[string]interface{}{"id": sess.UID, "display_name": sess.Login},
@@ -365,3 +367,105 @@ func (s *Server) handleTranslations(w http.ResponseWriter, r *http.Request) {
"multi_lang": multiLang,
})
}
// buildUserCompanies queries company data and builds the user_companies dict
// for the session_info response. Mirrors: odoo/addons/web/models/ir_http.py
func (s *Server) buildUserCompanies(currentCompanyID int64, allowedIDs []int64) map[string]interface{} {
allowedCompanies := make(map[string]interface{})
// Batch query all companies at once
rows, err := s.pool.Query(context.Background(),
`SELECT id, COALESCE(name, 'Company'), COALESCE(currency_id, 1)
FROM res_company WHERE id = ANY($1)`, allowedIDs)
if err == nil {
defer rows.Close()
for rows.Next() {
var cid, currencyID int64
var name string
if rows.Scan(&cid, &name, &currencyID) == nil {
allowedCompanies[fmt.Sprintf("%d", cid)] = map[string]interface{}{
"id": cid,
"name": name,
"sequence": 10,
"child_ids": []int64{},
"parent_id": false,
"currency_id": currencyID,
}
}
}
}
// Fallback for any IDs not found in DB
for _, cid := range allowedIDs {
key := fmt.Sprintf("%d", cid)
if _, exists := allowedCompanies[key]; !exists {
allowedCompanies[key] = map[string]interface{}{
"id": cid, "name": fmt.Sprintf("Company %d", cid),
"sequence": 10, "child_ids": []int64{}, "parent_id": false, "currency_id": int64(1),
}
}
}
return map[string]interface{}{
"current_company": currentCompanyID,
"allowed_companies": allowedCompanies,
"disallowed_ancestor_companies": map[string]interface{}{},
}
}
// handleSwitchCompany switches the active company for the current session.
func (s *Server) handleSwitchCompany(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
sess := GetSession(r)
if sess == nil {
s.writeJSONRPC(w, nil, nil, &RPCError{Code: 100, Message: "Not authenticated"})
return
}
var req JSONRPCRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
s.writeJSONRPC(w, nil, nil, &RPCError{Code: -32700, Message: "Parse error"})
return
}
var params struct {
CompanyID int64 `json:"company_id"`
}
if err := json.Unmarshal(req.Params, &params); err != nil || params.CompanyID == 0 {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: -32602, Message: "Invalid company_id"})
return
}
// Validate company is in allowed list
allowed := false
for _, cid := range sess.AllowedCompanyIDs {
if cid == params.CompanyID {
allowed = true
break
}
}
if !allowed {
s.writeJSONRPC(w, req.ID, nil, &RPCError{Code: 403, Message: "Company not in allowed list"})
return
}
// Update session
sess.CompanyID = params.CompanyID
// Persist to DB
if s.sessions.pool != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.sessions.pool.Exec(ctx,
`UPDATE sessions SET company_id = $1 WHERE id = $2`, params.CompanyID, sess.ID)
}
s.writeJSONRPC(w, req.ID, map[string]interface{}{
"company_id": params.CompanyID,
"result": true,
}, nil)
}

160
pkg/service/automation.go Normal file
View File

@@ -0,0 +1,160 @@
package service
import (
"fmt"
"log"
"strings"
"github.com/jackc/pgx/v5"
"odoo-go/pkg/orm"
"odoo-go/pkg/tools"
)
// RunAutomatedActions checks and executes server actions triggered by Create/Write/Unlink.
// Called from the ORM after successful Create/Write/Unlink operations.
// Mirrors: odoo/addons/base_automation/models/base_automation.py
func RunAutomatedActions(env *orm.Environment, modelName, trigger string, recordIDs []int64) {
if len(recordIDs) == 0 {
return
}
// Look up the ir_model ID for this model
var modelID int64
err := env.Tx().QueryRow(env.Ctx(),
`SELECT id FROM ir_model WHERE model = $1`, modelName).Scan(&modelID)
if err != nil {
return // Model not in ir_model — no actions possible
}
// Find matching automated actions
rows, err := env.Tx().Query(env.Ctx(),
`SELECT id, state, COALESCE(update_field_id, ''), COALESCE(update_value, ''),
COALESCE(email_to, ''), COALESCE(email_subject, ''), COALESCE(email_body, ''),
COALESCE(filter_domain, '')
FROM ir_act_server
WHERE model_id = $1
AND active = true
AND trigger = $2
ORDER BY sequence, id`, modelID, trigger)
if err != nil {
log.Printf("automation: query error for %s/%s: %v", modelName, trigger, err)
return
}
defer rows.Close()
type action struct {
id int64
state string
updateField string
updateValue string
emailTo string
emailSubject string
emailBody string
filterDomain string
}
var actions []action
for rows.Next() {
var a action
if err := rows.Scan(&a.id, &a.state, &a.updateField, &a.updateValue,
&a.emailTo, &a.emailSubject, &a.emailBody, &a.filterDomain); err != nil {
continue
}
actions = append(actions, a)
}
if len(actions) == 0 {
return
}
for _, a := range actions {
switch a.state {
case "object_write":
executeObjectWrite(env, modelName, recordIDs, a.updateField, a.updateValue)
case "email":
executeEmailAction(env, modelName, recordIDs, a.emailTo, a.emailSubject, a.emailBody)
}
}
}
// executeObjectWrite updates a field on the triggered records.
func executeObjectWrite(env *orm.Environment, modelName string, recordIDs []int64, fieldName, value string) {
if fieldName == "" {
return
}
tableName := strings.ReplaceAll(modelName, ".", "_")
for _, id := range recordIDs {
_, err := env.Tx().Exec(env.Ctx(),
fmt.Sprintf(`UPDATE %s SET %s = $1 WHERE id = $2`,
pgx.Identifier{tableName}.Sanitize(),
pgx.Identifier{fieldName}.Sanitize()),
value, id)
if err != nil {
log.Printf("automation: object_write error %s.%s on %d: %v", modelName, fieldName, id, err)
}
}
}
// executeEmailAction sends an email for each triggered record.
func executeEmailAction(env *orm.Environment, modelName string, recordIDs []int64, emailToField, subject, bodyTemplate string) {
if emailToField == "" {
return
}
cfg := tools.LoadSMTPConfig()
if cfg.Host == "" {
return
}
tableName := strings.ReplaceAll(modelName, ".", "_")
for _, id := range recordIDs {
// Resolve email address from the record
var email string
err := env.Tx().QueryRow(env.Ctx(),
fmt.Sprintf(`SELECT COALESCE(%s, '') FROM %s WHERE id = $1`,
pgx.Identifier{emailToField}.Sanitize(),
pgx.Identifier{tableName}.Sanitize()),
id).Scan(&email)
if err != nil || email == "" {
continue
}
// Simple template: replace {{field}} with record values
body := bodyTemplate
if strings.Contains(body, "{{") {
body = resolveTemplate(env, tableName, id, body)
}
if err := tools.SendEmail(cfg, email, subject, body); err != nil {
log.Printf("automation: email error to %s for %s/%d: %v", email, modelName, id, err)
}
}
}
// resolveTemplate replaces {{field_name}} placeholders with actual record values.
func resolveTemplate(env *orm.Environment, tableName string, recordID int64, template string) string {
result := template
for {
start := strings.Index(result, "{{")
if start == -1 {
break
}
end := strings.Index(result[start:], "}}")
if end == -1 {
break
}
fieldName := strings.TrimSpace(result[start+2 : start+end])
var val string
err := env.Tx().QueryRow(env.Ctx(),
fmt.Sprintf(`SELECT COALESCE(CAST(%s AS TEXT), '') FROM %s WHERE id = $1`,
pgx.Identifier{fieldName}.Sanitize(),
pgx.Identifier{tableName}.Sanitize()),
recordID).Scan(&val)
if err != nil {
val = ""
}
result = result[:start] + val + result[start+end+2:]
}
return result
}

View File

@@ -2,57 +2,70 @@ package service
import (
"context"
"fmt"
"log"
"sync"
"time"
"github.com/jackc/pgx/v5/pgxpool"
"odoo-go/pkg/orm"
)
// CronJob defines a scheduled task.
type CronJob struct {
Name string
Interval time.Duration
Handler func(ctx context.Context, pool *pgxpool.Pool) error
running bool
const (
cronPollInterval = 60 * time.Second
maxFailureCount = 5
)
// cronJob holds a single scheduled action loaded from the ir_cron table.
type cronJob struct {
ID int64
Name string
ModelName string
MethodName string
UserID int64
IntervalNumber int
IntervalType string
NumberCall int
NextCall time.Time
}
// CronScheduler manages periodic jobs.
// CronScheduler polls ir_cron and executes ready jobs.
// Mirrors: odoo/addons/base/models/ir_cron.py IrCron._process_jobs()
type CronScheduler struct {
jobs []*CronJob
mu sync.Mutex
pool *pgxpool.Pool
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewCronScheduler creates a new scheduler.
func NewCronScheduler() *CronScheduler {
// NewCronScheduler creates a DB-driven cron scheduler.
func NewCronScheduler(pool *pgxpool.Pool) *CronScheduler {
ctx, cancel := context.WithCancel(context.Background())
return &CronScheduler{ctx: ctx, cancel: cancel}
return &CronScheduler{pool: pool, ctx: ctx, cancel: cancel}
}
// Register adds a job to the scheduler.
func (s *CronScheduler) Register(job *CronJob) {
s.mu.Lock()
defer s.mu.Unlock()
s.jobs = append(s.jobs, job)
// Start begins the polling loop in a background goroutine.
func (s *CronScheduler) Start() {
s.wg.Add(1)
go s.pollLoop()
log.Println("cron: scheduler started")
}
// Start begins running all registered jobs.
func (s *CronScheduler) Start(pool *pgxpool.Pool) {
for _, job := range s.jobs {
go s.runJob(job, pool)
}
log.Printf("cron: started %d jobs", len(s.jobs))
}
// Stop cancels all running jobs.
// Stop cancels the polling loop and waits for completion.
func (s *CronScheduler) Stop() {
s.cancel()
s.wg.Wait()
log.Println("cron: scheduler stopped")
}
func (s *CronScheduler) runJob(job *CronJob, pool *pgxpool.Pool) {
ticker := time.NewTicker(job.Interval)
func (s *CronScheduler) pollLoop() {
defer s.wg.Done()
// Run once immediately, then on ticker
s.processJobs()
ticker := time.NewTicker(cronPollInterval)
defer ticker.Stop()
for {
@@ -60,9 +73,200 @@ func (s *CronScheduler) runJob(job *CronJob, pool *pgxpool.Pool) {
case <-s.ctx.Done():
return
case <-ticker.C:
if err := job.Handler(s.ctx, pool); err != nil {
log.Printf("cron: %s error: %v", job.Name, err)
}
s.processJobs()
}
}
}
// processJobs queries all ready cron jobs and processes them one by one.
func (s *CronScheduler) processJobs() {
rows, err := s.pool.Query(s.ctx, `
SELECT id, name, model_name, method_name, user_id,
interval_number, interval_type, numbercall, nextcall
FROM ir_cron
WHERE active = true AND nextcall <= now()
ORDER BY priority, id
`)
if err != nil {
log.Printf("cron: query error: %v", err)
return
}
var jobs []cronJob
for rows.Next() {
var j cronJob
var modelName, methodName *string // nullable
if err := rows.Scan(&j.ID, &j.Name, &modelName, &methodName, &j.UserID,
&j.IntervalNumber, &j.IntervalType, &j.NumberCall, &j.NextCall); err != nil {
log.Printf("cron: scan error: %v", err)
continue
}
if modelName != nil {
j.ModelName = *modelName
}
if methodName != nil {
j.MethodName = *methodName
}
jobs = append(jobs, j)
}
rows.Close()
for _, job := range jobs {
s.processOneJob(job)
}
}
// processOneJob acquires a row-level lock and executes a single cron job.
func (s *CronScheduler) processOneJob(job cronJob) {
tx, err := s.pool.Begin(s.ctx)
if err != nil {
return
}
defer tx.Rollback(s.ctx)
// Try to acquire the job with FOR NO KEY UPDATE SKIP LOCKED
var lockedID int64
err = tx.QueryRow(s.ctx, `
SELECT id FROM ir_cron
WHERE id = $1 AND active = true AND nextcall <= now()
FOR NO KEY UPDATE SKIP LOCKED
`, job.ID).Scan(&lockedID)
if err != nil {
// Job already taken by another worker or not ready
return
}
log.Printf("cron: executing %q (id=%d)", job.Name, job.ID)
execErr := s.executeJob(job)
now := time.Now()
nextCall := calculateNextCall(now, job.IntervalNumber, job.IntervalType)
if execErr != nil {
log.Printf("cron: %q failed: %v", job.Name, execErr)
// Update failure count, set first_failure_date if not already set
if _, err := tx.Exec(s.ctx, `
UPDATE ir_cron SET
failure_count = failure_count + 1,
first_failure_date = COALESCE(first_failure_date, $1),
lastcall = $1,
nextcall = $2
WHERE id = $3
`, now, nextCall, job.ID); err != nil {
log.Printf("cron: failed to update failure count for %q: %v", job.Name, err)
}
// Deactivate if too many consecutive failures
if _, err := tx.Exec(s.ctx, `
UPDATE ir_cron SET active = false
WHERE id = $1 AND failure_count >= $2
`, job.ID, maxFailureCount); err != nil {
log.Printf("cron: failed to deactivate %q: %v", job.Name, err)
}
} else {
log.Printf("cron: %q completed successfully", job.Name)
if job.NumberCall > 0 {
// Finite run count: decrement
newNumberCall := job.NumberCall - 1
if newNumberCall <= 0 {
if _, err := tx.Exec(s.ctx, `
UPDATE ir_cron SET active = false, lastcall = $1, nextcall = $2,
failure_count = 0, first_failure_date = NULL, numbercall = 0
WHERE id = $3
`, now, nextCall, job.ID); err != nil {
log.Printf("cron: failed to update job %q: %v", job.Name, err)
}
} else {
if _, err := tx.Exec(s.ctx, `
UPDATE ir_cron SET lastcall = $1, nextcall = $2,
failure_count = 0, first_failure_date = NULL, numbercall = $3
WHERE id = $4
`, now, nextCall, newNumberCall, job.ID); err != nil {
log.Printf("cron: failed to update job %q: %v", job.Name, err)
}
}
} else {
// numbercall <= 0 means infinite runs
if _, err := tx.Exec(s.ctx, `
UPDATE ir_cron SET lastcall = $1, nextcall = $2,
failure_count = 0, first_failure_date = NULL
WHERE id = $3
`, now, nextCall, job.ID); err != nil {
log.Printf("cron: failed to update job %q: %v", job.Name, err)
}
}
}
if err := tx.Commit(s.ctx); err != nil {
log.Printf("cron: commit error for %q: %v", job.Name, err)
}
}
// executeJob looks up the target method in orm.Registry and calls it.
func (s *CronScheduler) executeJob(job cronJob) error {
if job.ModelName == "" || job.MethodName == "" {
return fmt.Errorf("cron %q: model_name or method_name not set", job.Name)
}
model := orm.Registry.Get(job.ModelName)
if model == nil {
return fmt.Errorf("cron %q: model %q not found", job.Name, job.ModelName)
}
if model.Methods == nil {
return fmt.Errorf("cron %q: model %q has no methods", job.Name, job.ModelName)
}
method, ok := model.Methods[job.MethodName]
if !ok {
return fmt.Errorf("cron %q: method %q not found on %q", job.Name, job.MethodName, job.ModelName)
}
// Create ORM environment for job execution
uid := job.UserID
if uid == 0 {
return fmt.Errorf("cron %q: user_id not set, refusing to run as admin", job.Name)
}
env, err := orm.NewEnvironment(s.ctx, orm.EnvConfig{
Pool: s.pool,
UID: uid,
Context: map[string]interface{}{
"lastcall": job.NextCall,
"cron_id": job.ID,
},
})
if err != nil {
return fmt.Errorf("cron %q: env error: %w", job.Name, err)
}
defer env.Close()
// Call the method on an empty recordset of the target model
_, err = method(env.Model(job.ModelName))
if err != nil {
env.Rollback()
return err
}
return env.Commit()
}
// calculateNextCall computes the next execution time based on interval.
// Mirrors: odoo/addons/base/models/ir_cron.py _intervalTypes
func calculateNextCall(from time.Time, number int, intervalType string) time.Time {
switch intervalType {
case "minutes":
return from.Add(time.Duration(number) * time.Minute)
case "hours":
return from.Add(time.Duration(number) * time.Hour)
case "days":
return from.AddDate(0, 0, number)
case "weeks":
return from.AddDate(0, 0, number*7)
case "months":
return from.AddDate(0, number, 0)
default:
return from.Add(time.Duration(number) * time.Hour)
}
}

View File

@@ -330,6 +330,8 @@ func SeedWithSetup(ctx context.Context, pool *pgxpool.Pool, cfg SetupConfig) err
VALUES (1, 1, true, 1, 1) ON CONFLICT (id) DO NOTHING`)
})
safeExec(ctx, tx, "base_groups", func() { seedBaseGroups(ctx, tx) })
safeExec(ctx, tx, "acl_rules", func() { seedACLRules(ctx, tx) })
safeExec(ctx, tx, "system_params", func() { seedSystemParams(ctx, tx) })
safeExec(ctx, tx, "languages", func() { seedLanguages(ctx, tx) })
safeExec(ctx, tx, "translations", func() { seedTranslations(ctx, tx) })
@@ -1676,3 +1678,136 @@ func generateUUID() string {
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}
// seedBaseGroups creates the base security groups and their XML IDs.
// Mirrors: odoo/addons/base/security/base_groups.xml
func seedBaseGroups(ctx context.Context, tx pgx.Tx) {
log.Println("db: seeding base security groups...")
type groupDef struct {
id int64
name string
xmlID string
}
groups := []groupDef{
{1, "Internal User", "group_user"},
{2, "Settings", "group_system"},
{3, "Access Rights", "group_erp_manager"},
{4, "Allow Export", "group_allow_export"},
{5, "Portal", "group_portal"},
{6, "Public", "group_public"},
}
for _, g := range groups {
tx.Exec(ctx, `INSERT INTO res_groups (id, name)
VALUES ($1, $2) ON CONFLICT (id) DO NOTHING`, g.id, g.name)
tx.Exec(ctx, `INSERT INTO ir_model_data (module, name, model, res_id)
VALUES ('base', $1, 'res.groups', $2) ON CONFLICT DO NOTHING`, g.xmlID, g.id)
}
// Add admin user (uid=1) to all groups
for _, g := range groups {
tx.Exec(ctx, `INSERT INTO res_groups_res_users_rel (res_groups_id, res_users_id)
VALUES ($1, 1) ON CONFLICT DO NOTHING`, g.id)
}
}
// seedACLRules creates access control entries for ALL registered models.
// Categorizes models into security tiers and assigns appropriate permissions.
// Mirrors: odoo/addons/base/security/ir.model.access.csv + per-module CSVs
func seedACLRules(ctx context.Context, tx pgx.Tx) {
log.Println("db: seeding ACL rules for all models...")
// Resolve group IDs
var groupSystem, groupUser int64
err := tx.QueryRow(ctx,
`SELECT g.id FROM res_groups g
JOIN ir_model_data imd ON imd.res_id = g.id AND imd.model = 'res.groups'
WHERE imd.module = 'base' AND imd.name = 'group_system'`).Scan(&groupSystem)
if err != nil {
log.Printf("db: cannot find group_system, skipping ACL seeding: %v", err)
return
}
err = tx.QueryRow(ctx,
`SELECT g.id FROM res_groups g
JOIN ir_model_data imd ON imd.res_id = g.id AND imd.model = 'res.groups'
WHERE imd.module = 'base' AND imd.name = 'group_user'`).Scan(&groupUser)
if err != nil {
log.Printf("db: cannot find group_user, skipping ACL seeding: %v", err)
return
}
// ── Security Tiers ──────────────────────────────────────────────
// Tier 1: System-only — only group_system gets full access
systemOnly := map[string]bool{
"ir.cron": true, "ir.rule": true, "ir.model.access": true,
}
// Tier 2: Admin-only — group_user=read, group_system=full
adminOnly := map[string]bool{
"ir.model": true, "ir.model.fields": true, "ir.model.data": true,
"ir.module.category": true, "ir.actions.server": true, "ir.sequence": true,
"ir.logging": true, "ir.config_parameter": true, "ir.default": true,
"ir.translation": true, "ir.actions.report": true, "report.paperformat": true,
"res.config.settings": true,
}
// Tier 3: Read-only for users — group_user=read, group_system=full
readOnly := map[string]bool{
"res.currency": true, "res.currency.rate": true,
"res.country": true, "res.country.state": true, "res.country.group": true,
"res.lang": true, "uom.category": true, "uom.uom": true,
"product.category": true, "product.removal": true,
"account.account.tag": true, "account.group": true,
"account.tax.group": true, "account.tax.repartition.line": true,
}
// Everything else → Tier 4: Standard user (group_user=full, group_system=full)
// Helper to insert an ACL rule
insertACL := func(modelID int64, modelName string, groupID int64, suffix string, read, write, create, unlink bool) {
aclName := "access_" + strings.ReplaceAll(modelName, ".", "_") + "_" + suffix
tx.Exec(ctx, `
INSERT INTO ir_model_access (name, model_id, group_id, perm_read, perm_write, perm_create, perm_unlink, active)
VALUES ($1, $2, $3, $4, $5, $6, $7, true)
ON CONFLICT DO NOTHING`,
aclName, modelID, groupID, read, write, create, unlink)
}
// Iterate ALL registered models
allModels := orm.Registry.Models()
seeded := 0
for _, m := range allModels {
modelName := m.Name()
if m.IsAbstract() {
continue // Abstract models have no table → no ACL needed
}
// Look up ir_model ID
var modelID int64
err := tx.QueryRow(ctx,
"SELECT id FROM ir_model WHERE model = $1", modelName).Scan(&modelID)
if err != nil {
continue // Not yet in ir_model — will be seeded on next restart
}
if systemOnly[modelName] {
// Tier 1: only group_system full access
insertACL(modelID, modelName, groupSystem, "system", true, true, true, true)
} else if adminOnly[modelName] {
// Tier 2: group_user=read, group_system=full
insertACL(modelID, modelName, groupUser, "user_read", true, false, false, false)
insertACL(modelID, modelName, groupSystem, "system", true, true, true, true)
} else if readOnly[modelName] {
// Tier 3: group_user=read, group_system=full
insertACL(modelID, modelName, groupUser, "user_read", true, false, false, false)
insertACL(modelID, modelName, groupSystem, "system", true, true, true, true)
} else {
// Tier 4: group_user=full, group_system=full
insertACL(modelID, modelName, groupUser, "user", true, true, true, true)
insertACL(modelID, modelName, groupSystem, "system", true, true, true, true)
}
seeded++
}
log.Printf("db: seeded ACL rules for %d models", seeded)
}

255
pkg/service/fetchmail.go Normal file
View File

@@ -0,0 +1,255 @@
package service
import (
"context"
"fmt"
"io"
"log"
"os"
"strings"
"github.com/emersion/go-imap/v2"
"github.com/emersion/go-imap/v2/imapclient"
gomessage "github.com/emersion/go-message"
_ "github.com/emersion/go-message/charset"
"github.com/jackc/pgx/v5/pgxpool"
)
// FetchmailConfig holds IMAP server configuration.
type FetchmailConfig struct {
Host string
Port int
User string
Password string
UseTLS bool
Folder string
}
// LoadFetchmailConfig loads IMAP settings from environment variables.
func LoadFetchmailConfig() *FetchmailConfig {
cfg := &FetchmailConfig{
Port: 993,
UseTLS: true,
Folder: "INBOX",
}
cfg.Host = os.Getenv("IMAP_HOST")
cfg.User = os.Getenv("IMAP_USER")
cfg.Password = os.Getenv("IMAP_PASSWORD")
if v := os.Getenv("IMAP_FOLDER"); v != "" {
cfg.Folder = v
}
if os.Getenv("IMAP_TLS") == "false" {
cfg.UseTLS = false
if cfg.Port == 993 {
cfg.Port = 143
}
}
return cfg
}
// FetchAndProcessEmails connects to IMAP, fetches unseen emails, and creates
// mail.message records in the database. Matches emails to existing threads
// via In-Reply-To/References headers.
// Mirrors: odoo/addons/fetchmail/models/fetchmail.py fetch_mail()
func FetchAndProcessEmails(ctx context.Context, pool *pgxpool.Pool) error {
cfg := LoadFetchmailConfig()
if cfg.Host == "" {
return nil // IMAP not configured
}
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
var c *imapclient.Client
var err error
if cfg.UseTLS {
c, err = imapclient.DialTLS(addr, nil)
} else {
c, err = imapclient.DialInsecure(addr, nil)
}
if err != nil {
return fmt.Errorf("fetchmail: connect to %s: %w", addr, err)
}
defer c.Close()
if err := c.Login(cfg.User, cfg.Password).Wait(); err != nil {
return fmt.Errorf("fetchmail: login as %s: %w", cfg.User, err)
}
if _, err := c.Select(cfg.Folder, nil).Wait(); err != nil {
return fmt.Errorf("fetchmail: select %s: %w", cfg.Folder, err)
}
// Search unseen
criteria := &imap.SearchCriteria{
NotFlag: []imap.Flag{imap.FlagSeen},
}
searchData, err := c.Search(criteria, nil).Wait()
if err != nil {
return fmt.Errorf("fetchmail: search: %w", err)
}
seqSet := searchData.All
if seqSet == nil {
return nil
}
// Fetch envelope + body
fetchOpts := &imap.FetchOptions{
Envelope: true,
BodySection: []*imap.FetchItemBodySection{{}},
}
msgs, err := c.Fetch(seqSet, fetchOpts).Collect()
if err != nil {
return fmt.Errorf("fetchmail: fetch: %w", err)
}
var processed int
for _, msg := range msgs {
if err := processOneEmail(ctx, pool, msg); err != nil {
log.Printf("fetchmail: process error: %v", err)
continue
}
processed++
}
// Mark as seen
if processed > 0 {
storeFlags := &imap.StoreFlags{
Op: imap.StoreFlagsAdd,
Flags: []imap.Flag{imap.FlagSeen},
}
if _, err := c.Store(seqSet, storeFlags, nil).Collect(); err != nil {
log.Printf("fetchmail: mark seen error: %v", err)
}
log.Printf("fetchmail: processed %d new emails", processed)
}
return nil
}
func processOneEmail(ctx context.Context, pool *pgxpool.Pool, buf *imapclient.FetchMessageBuffer) error {
env := buf.Envelope
if env == nil {
return fmt.Errorf("no envelope")
}
subject := env.Subject
messageID := env.MessageID
inReplyTo := env.InReplyTo
date := env.Date
var fromEmail, fromName string
if len(env.From) > 0 {
fromEmail = fmt.Sprintf("%s@%s", env.From[0].Mailbox, env.From[0].Host)
fromName = env.From[0].Name
}
// Extract body from body section
var bodyText string
bodyBytes := buf.FindBodySection(&imap.FetchItemBodySection{})
if bodyBytes != nil {
bodyText = parseEmailBody(bodyBytes)
}
if bodyText == "" {
bodyText = "(no body)"
}
// Find author partner by email
var authorID int64
pool.QueryRow(ctx,
`SELECT id FROM res_partner WHERE LOWER(email) = LOWER($1) LIMIT 1`, fromEmail,
).Scan(&authorID)
// Thread matching via In-Reply-To
var parentModel string
var parentResID int64
if len(inReplyTo) > 0 && inReplyTo[0] != "" {
pool.QueryRow(ctx,
`SELECT model, res_id FROM mail_message
WHERE message_id = $1 AND model IS NOT NULL AND res_id IS NOT NULL
LIMIT 1`, inReplyTo[0],
).Scan(&parentModel, &parentResID)
}
// Fallback: match by subject
if parentModel == "" && subject != "" {
clean := subject
for _, prefix := range []string{"Re: ", "RE: ", "Fwd: ", "FW: ", "AW: "} {
clean = strings.TrimPrefix(clean, prefix)
}
pool.QueryRow(ctx,
`SELECT model, res_id FROM mail_message
WHERE subject = $1 AND model IS NOT NULL AND res_id IS NOT NULL
ORDER BY id DESC LIMIT 1`, clean,
).Scan(&parentModel, &parentResID)
}
_, err := pool.Exec(ctx,
`INSERT INTO mail_message
(subject, body, message_type, email_from, author_id, model, res_id,
date, message_id, create_uid, write_uid, create_date, write_date)
VALUES ($1, $2, 'email', $3, $4, $5, $6, $7, $8, 1, 1, NOW(), NOW())`,
subject, bodyText,
fmt.Sprintf("%s <%s>", fromName, fromEmail),
nilIfZero(authorID),
nilIfEmpty(parentModel),
nilIfZero(parentResID),
date,
messageID,
)
return err
}
func parseEmailBody(raw []byte) string {
entity, err := gomessage.Read(strings.NewReader(string(raw)))
if err != nil {
return string(raw) // fallback: raw text
}
if mr := entity.MultipartReader(); mr != nil {
var htmlBody, textBody string
for {
part, err := mr.NextPart()
if err != nil {
break
}
ct, _, _ := part.Header.ContentType()
body, _ := io.ReadAll(part.Body)
switch {
case strings.HasPrefix(ct, "text/html"):
htmlBody = string(body)
case strings.HasPrefix(ct, "text/plain"):
textBody = string(body)
}
}
if htmlBody != "" {
return htmlBody
}
return textBody
}
// Single part
body, _ := io.ReadAll(entity.Body)
return string(body)
}
func nilIfZero(v int64) interface{} {
if v == 0 {
return nil
}
return v
}
func nilIfEmpty(v string) interface{} {
if v == "" {
return nil
}
return v
}
// RegisterFetchmailCron ensures the message_id column exists for thread matching.
func RegisterFetchmailCron(ctx context.Context, pool *pgxpool.Pool) {
pool.Exec(ctx, `ALTER TABLE mail_message ADD COLUMN IF NOT EXISTS message_id VARCHAR(255)`)
pool.Exec(ctx, `CREATE INDEX IF NOT EXISTS idx_mail_message_message_id ON mail_message(message_id)`)
log.Println("fetchmail: ready (IMAP config via IMAP_HOST/IMAP_USER/IMAP_PASSWORD env vars)")
}

View File

@@ -1,6 +1,7 @@
package tools
import (
"crypto/rand"
"encoding/base64"
"fmt"
"log"
@@ -64,8 +65,12 @@ func SendEmail(cfg *SMTPConfig, to, subject, body string) error {
return nil
}
// Sanitize headers to prevent injection via \r\n
sanitize := func(s string) string {
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
}
msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=utf-8\r\n\r\n%s",
cfg.From, to, subject, body)
sanitize(cfg.From), sanitize(to), sanitize(subject), body)
auth := smtp.PlainAuth("", cfg.User, cfg.Password, cfg.Host)
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
@@ -83,12 +88,17 @@ func SendEmailWithAttachments(cfg *SMTPConfig, to []string, subject, bodyHTML st
}
addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
boundary := "==odoo-go-boundary-42=="
b := make([]byte, 16)
rand.Read(b)
boundary := fmt.Sprintf("==odoo-go-%x==", b)
sanitize := func(s string) string {
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
}
var msg strings.Builder
msg.WriteString(fmt.Sprintf("From: %s\r\n", cfg.From))
msg.WriteString(fmt.Sprintf("To: %s\r\n", strings.Join(to, ", ")))
msg.WriteString(fmt.Sprintf("Subject: %s\r\n", subject))
msg.WriteString(fmt.Sprintf("From: %s\r\n", sanitize(cfg.From)))
msg.WriteString(fmt.Sprintf("To: %s\r\n", sanitize(strings.Join(to, ", "))))
msg.WriteString(fmt.Sprintf("Subject: %s\r\n", sanitize(subject)))
msg.WriteString("MIME-Version: 1.0\r\n")
if len(attachments) > 0 {